嘉渊
2023-04-24 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1
funasr/tasks/abs_task.py
@@ -25,10 +25,12 @@
import humanfriendly
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing
import torch.nn
import torch.optim
import yaml
from funasr.train.abs_espnet_model import AbsESPnetModel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from typeguard import check_return_type
@@ -38,22 +40,23 @@
from funasr.datasets.dataset import DATA_TYPES
from funasr.datasets.dataset import ESPnetDataset
from funasr.datasets.iterable_dataset import IterableESPnetDataset
from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
from funasr.main_funcs.collect_stats import collect_stats
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD
from funasr.samplers.build_batch_sampler import BATCH_TYPES
from funasr.samplers.build_batch_sampler import build_batch_sampler
from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.distributed_utils import DistributedOption
from funasr.train.trainer import Trainer
@@ -68,6 +71,7 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_int
from funasr.utils.types import str_or_none
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
try:
@@ -82,6 +86,7 @@
optim_classes = dict(
    adam=torch.optim.Adam,
    fairseq_adam=FairseqAdam,
    adamw=torch.optim.AdamW,
    sgd=SGD,
    adadelta=torch.optim.Adadelta,
@@ -148,6 +153,7 @@
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
@@ -182,6 +188,7 @@
    num_optimizers: int = 1
    trainer = Trainer
    class_choices_list: List[ClassChoices] = []
    finetune_args: None
    def __init__(self):
        raise RuntimeError("This class can't be instantiated.")
@@ -279,7 +286,7 @@
        # NOTE(kamo): add_arguments(..., required=True) can't be used
        #  to provide --print_config mode. Instead of it, do as
        parser.set_defaults(required=["output_dir"])
        # parser.set_defaults(required=["output_dir"])
        group = parser.add_argument_group("Common configuration")
@@ -457,6 +464,12 @@
            default=sys.maxsize,
            help="The maximum number update step to train",
        )
        parser.add_argument(
            "--batch_interval",
            type=int,
            default=10000,
            help="The batch interval for saving model.",
        )
        group.add_argument(
            "--patience",
            type=int_or_none,
@@ -632,12 +645,12 @@
                 "and exclude_keys excludes keys of model states for the initialization."
                 "e.g.\n"
                 "  # Load all parameters"
                 "  --init_param some/where/model.pth\n"
                 "  --init_param some/where/model.pb\n"
                 "  # Load only decoder parameters"
                 "  --init_param some/where/model.pth:decoder:decoder\n"
                 "  --init_param some/where/model.pb:decoder:decoder\n"
                 "  # Load only decoder parameters excluding decoder.embed"
                 "  --init_param some/where/model.pth:decoder:decoder:decoder.embed\n"
                 "  --init_param some/where/model.pth:decoder:decoder:decoder.embed\n",
                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
        )
        group.add_argument(
            "--ignore_init_mismatch",
@@ -696,7 +709,7 @@
        group.add_argument(
            "--batch_type",
            type=str,
            default="folded",
            default="length",
            choices=list(BATCH_TYPES),
            help=_batch_type_help,
        )
@@ -706,6 +719,18 @@
            default=None,
            choices=list(BATCH_TYPES) + [None],
            help="If not given, the value of --batch_type is used",
        )
        group.add_argument(
            "--speech_length_min",
            type=int,
            default=-1,
            help="speech length min",
        )
        group.add_argument(
            "--speech_length_max",
            type=int,
            default=-1,
            help="speech length max",
        )
        group.add_argument("--fold_length", type=int, action="append", default=[])
        group.add_argument(
@@ -878,6 +903,11 @@
            help="flag to indicate whether training on PAI",
        )
        group.add_argument(
            "--simple_ddp",
            type=str2bool,
            default=False,
        )
        group.add_argument(
            "--num_worker_count",
            type=int,
            default=1,
@@ -1005,29 +1035,30 @@
    @classmethod
    def check_required_command_args(cls, args: argparse.Namespace):
        assert check_argument_types()
        for k in vars(args):
            if "-" in k:
                raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
        if hasattr(args, "required"):
            for k in vars(args):
                if "-" in k:
                    raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
        required = ", ".join(
            f"--{a}" for a in args.required if getattr(args, a) is None
        )
        if len(required) != 0:
            parser = cls.get_parser()
            parser.print_help(file=sys.stderr)
            p = Path(sys.argv[0]).name
            print(file=sys.stderr)
            print(
                f"{p}: error: the following arguments are required: " f"{required}",
                file=sys.stderr,
            required = ", ".join(
                f"--{a}" for a in args.required if getattr(args, a) is None
            )
            sys.exit(2)
            if len(required) != 0:
                parser = cls.get_parser()
                parser.print_help(file=sys.stderr)
                p = Path(sys.argv[0]).name
                print(file=sys.stderr)
                print(
                    f"{p}: error: the following arguments are required: " f"{required}",
                    file=sys.stderr,
                )
                sys.exit(2)
    @classmethod
    def check_task_requirements(
            cls,
            dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
            dataset: Union[AbsDataset, IterableESPnetDataset],
            allow_variable_data_keys: bool,
            train: bool,
            inference: bool = False,
@@ -1087,6 +1118,22 @@
            cls.main_worker(args)
    @classmethod
    def run(cls):
        assert hasattr(cls, "finetune_args")
        args = cls.finetune_args
        args.train_shape_file = None
        if args.distributed:
            args.simple_ddp = True
        else:
            args.simple_ddp = False
            args.ngpu = 1
        args.use_pai = False
        args.batch_type = "length"
        args.oss_bucket = None
        args.input_size = None
        cls.main_worker(args)
    @classmethod
    def main_worker(cls, args: argparse.Namespace):
        assert check_argument_types()
@@ -1095,8 +1142,51 @@
        # Setting distributed_option.dist_rank, etc.
        if args.use_pai:
            distributed_option.init_options_pai()
        else:
        elif not args.simple_ddp:
            distributed_option.init_options()
        # Invoking torch.distributed.init_process_group
        if args.use_pai:
            distributed_option.init_torch_distributed_pai(args)
        elif not args.simple_ddp:
            distributed_option.init_torch_distributed(args)
        elif args.distributed and args.simple_ddp:
            distributed_option.init_torch_distributed_pai(args)
            args.ngpu = dist.get_world_size()
            if args.dataset_type == "small":
                if args.batch_size is not None:
                    args.batch_size = args.batch_size * args.ngpu
                if args.batch_bins is not None:
                    args.batch_bins = args.batch_bins * args.ngpu
        # filter samples if wav.scp and text are mismatch
        if (
                args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                filter_wav_text(args.data_dir, args.train_set)
                filter_wav_text(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()
        if args.train_shape_file is None and args.dataset_type == "small":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
                           args.speech_length_max)
                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
                           args.speech_length_max)
            if args.simple_ddp:
                dist.barrier()
            args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
            args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]
        if args.train_data_file is None and args.dataset_type == "large":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                generate_data_list(args.data_dir, args.train_set)
                generate_data_list(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()
            args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
            args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")
        # NOTE(kamo): Don't use logging before invoking logging.basicConfig()
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
@@ -1112,23 +1202,27 @@
            # logging.basicConfig() is invoked in main_worker() instead of main()
            # because it can be invoked only once in a process.
            # FIXME(kamo): Should we use logging.getLogger()?
            # BUGFIX: Remove previous handlers and reset log level
            for handler in logging.root.handlers[:]:
                logging.root.removeHandler(handler)
            logging.basicConfig(
                level=args.log_level,
                format=f"[{os.uname()[1].split('.')[0]}]"
                       f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        else:
            # BUGFIX: Remove previous handlers and reset log level
            for handler in logging.root.handlers[:]:
                logging.root.removeHandler(handler)
            # Suppress logging if RANK != 0
            logging.basicConfig(
                level="ERROR",
                format=f"[{os.uname()[1].split('.')[0]}]"
                       f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        # Invoking torch.distributed.init_process_group
        if args.use_pai:
            distributed_option.init_torch_distributed_pai(args)
        else:
            distributed_option.init_torch_distributed(args)
        logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                       distributed_option.dist_rank,
                                                                       distributed_option.local_rank))
        # 1. Set random-seed
        set_all_random_seed(args.seed)
@@ -1202,6 +1296,52 @@
        if args.dry_run:
            pass
        elif args.collect_stats:
            # Perform on collect_stats mode. This mode has two roles
            # - Derive the length and dimension of all input data
            # - Accumulate feats, square values, and the length for whitening
            if args.valid_batch_size is None:
                args.valid_batch_size = args.batch_size
            if len(args.train_shape_file) != 0:
                train_key_file = args.train_shape_file[0]
            else:
                train_key_file = None
            if len(args.valid_shape_file) != 0:
                valid_key_file = args.valid_shape_file[0]
            else:
                valid_key_file = None
            collect_stats(
                model=model,
                train_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
                    key_file=train_key_file,
                    batch_size=args.batch_size,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                valid_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
                    key_file=valid_key_file,
                    batch_size=args.valid_batch_size,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                output_dir=output_dir,
                ngpu=args.ngpu,
                log_interval=args.log_interval,
                write_collected_feats=args.write_collected_feats,
            )
        else:
            logging.info("Training args: {}".format(args))
            # 6. Loads pre-trained model
@@ -1222,10 +1362,24 @@
            # 7. Build iterator factories
            if args.dataset_type == "large":
                from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list,
                                                   args.config, mode="train")
                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list,
                                                   args.config, mode="eval")
                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
                                                   frontend_conf=args.frontend_conf if hasattr(args,
                                                                                               "frontend_conf") else None,
                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                               "seg_dict_file") else None,
                                                   punc_dict_file=args.punc_list if hasattr(args,
                                                                                            "punc_list") else None,
                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
                                                   mode="train")
                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
                                                   frontend_conf=args.frontend_conf if hasattr(args,
                                                                                               "frontend_conf") else None,
                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                               "seg_dict_file") else None,
                                                   punc_dict_file=args.punc_list if hasattr(args,
                                                                                            "punc_list") else None,
                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
                                                   mode="eval")
            elif args.dataset_type == "small":
                train_iter_factory = cls.build_iter_factory(
                    args=args,
@@ -1437,12 +1591,18 @@
    ) -> AbsIterFactory:
        assert check_argument_types()
        if args.frontend_conf is not None and "fs" in args.frontend_conf:
            dest_sample_rate = args.frontend_conf["fs"]
        else:
            dest_sample_rate = 16000
        dataset = ESPnetDataset(
            iter_options.data_path_and_name_and_type,
            float_dtype=args.train_dtype,
            preprocess=iter_options.preprocess_fn,
            max_cache_size=iter_options.max_cache_size,
            max_cache_fd=iter_options.max_cache_fd,
            dest_sample_rate=dest_sample_rate,
        )
        cls.check_task_requirements(
            dataset, args.allow_variable_data_keys, train=iter_options.train
@@ -1713,6 +1873,8 @@
            collate_fn,
            key_file: str = None,
            batch_size: int = 1,
            fs: dict = None,
            mc: bool = False,
            dtype: str = np.float32,
            num_workers: int = 1,
            allow_variable_data_keys: bool = False,
@@ -1730,6 +1892,8 @@
        dataset = IterableESPnetDataset(
            data_path_and_name_and_type,
            float_dtype=dtype,
            fs=fs,
            mc=mc,
            preprocess=preprocess_fn,
            key_file=key_file,
        )
@@ -1749,70 +1913,13 @@
            **kwargs,
        )
    @classmethod
    def build_streaming_iterator_modelscope(
            cls,
            data_path_and_name_and_type,
            preprocess_fn,
            collate_fn,
            key_file: str = None,
            batch_size: int = 1,
            dtype: str = np.float32,
            num_workers: int = 1,
            allow_variable_data_keys: bool = False,
            ngpu: int = 0,
            inference: bool = False,
            sample_rate: Union[dict, int] = 16000
    ) -> DataLoader:
        """Build DataLoader using iterable dataset"""
        assert check_argument_types()
        # For backward compatibility for pytorch DataLoader
        if collate_fn is not None:
            kwargs = dict(collate_fn=collate_fn)
        else:
            kwargs = {}
        audio_data = data_path_and_name_and_type[0]
        if isinstance(audio_data, bytes):
            dataset = IterableESPnetBytesModelScope(
                data_path_and_name_and_type,
                float_dtype=dtype,
                preprocess=preprocess_fn,
                key_file=key_file,
                sample_rate=sample_rate
            )
        else:
            dataset = IterableESPnetDatasetModelScope(
                data_path_and_name_and_type,
                float_dtype=dtype,
                preprocess=preprocess_fn,
                key_file=key_file,
                sample_rate=sample_rate
            )
        if dataset.apply_utt2category:
            kwargs.update(batch_size=1)
        else:
            kwargs.update(batch_size=batch_size)
        cls.check_task_requirements(dataset,
                                    allow_variable_data_keys,
                                    train=False,
                                    inference=inference)
        return DataLoader(
            dataset=dataset,
            pin_memory=ngpu > 0,
            num_workers=num_workers,
            **kwargs,
        )
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
            cls,
            config_file: Union[Path, str] = None,
            model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
        """Build model from the files.
@@ -1837,6 +1944,8 @@
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
@@ -1850,5 +1959,5 @@
                #   in PyTorch<=1.4
                device = f"cuda:{torch.cuda.current_device()}"
            model.load_state_dict(torch.load(model_file, map_location=device))
        model.to(device)
        return model, args