游雁
2023-02-14 1d4ab65c8bfebaecbcb0eec0064bae9a321cad75
funasr/tasks/abs_task.py
@@ -25,6 +25,7 @@
import humanfriendly
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing
import torch.nn
import torch.optim
@@ -38,17 +39,19 @@
from funasr.datasets.dataset import DATA_TYPES
from funasr.datasets.dataset import ESPnetDataset
from funasr.datasets.iterable_dataset import IterableESPnetDataset
from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
from funasr.main_funcs.collect_stats import collect_stats
from funasr.optimizers.sgd import SGD
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.samplers.build_batch_sampler import BATCH_TYPES
from funasr.samplers.build_batch_sampler import build_batch_sampler
from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.warmup_lr import WarmupLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@@ -68,6 +71,7 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_int
from funasr.utils.types import str_or_none
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
try:
@@ -82,6 +86,7 @@
optim_classes = dict(
    adam=torch.optim.Adam,
    fairseq_adam=FairseqAdam,
    adamw=torch.optim.AdamW,
    sgd=SGD,
    adadelta=torch.optim.Adadelta,
@@ -148,6 +153,7 @@
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
@@ -182,6 +188,7 @@
    num_optimizers: int = 1
    trainer = Trainer
    class_choices_list: List[ClassChoices] = []
    finetune_args: None
    def __init__(self):
        raise RuntimeError("This class can't be instantiated.")
@@ -279,7 +286,7 @@
        # NOTE(kamo): add_arguments(..., required=True) can't be used
        #  to provide --print_config mode. Instead of it, do as
        parser.set_defaults(required=["output_dir"])
        # parser.set_defaults(required=["output_dir"])
        group = parser.add_argument_group("Common configuration")
@@ -696,7 +703,7 @@
        group.add_argument(
            "--batch_type",
            type=str,
            default="folded",
            default="length",
            choices=list(BATCH_TYPES),
            help=_batch_type_help,
        )
@@ -706,6 +713,18 @@
            default=None,
            choices=list(BATCH_TYPES) + [None],
            help="If not given, the value of --batch_type is used",
        )
        group.add_argument(
            "--speech_length_min",
            type=int,
            default=-1,
            help="speech length min",
        )
        group.add_argument(
            "--speech_length_max",
            type=int,
            default=-1,
            help="speech length max",
        )
        group.add_argument("--fold_length", type=int, action="append", default=[])
        group.add_argument(
@@ -878,6 +897,11 @@
            help="flag to indicate whether training on PAI",
        )
        group.add_argument(
            "--simple_ddp",
            type=str2bool,
            default=False,
        )
        group.add_argument(
            "--num_worker_count",
            type=int,
            default=1,
@@ -1005,29 +1029,30 @@
    @classmethod
    def check_required_command_args(cls, args: argparse.Namespace):
        assert check_argument_types()
        for k in vars(args):
            if "-" in k:
                raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
        if hasattr(args, "required"):
            for k in vars(args):
                if "-" in k:
                    raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
        required = ", ".join(
            f"--{a}" for a in args.required if getattr(args, a) is None
        )
        if len(required) != 0:
            parser = cls.get_parser()
            parser.print_help(file=sys.stderr)
            p = Path(sys.argv[0]).name
            print(file=sys.stderr)
            print(
                f"{p}: error: the following arguments are required: " f"{required}",
                file=sys.stderr,
            required = ", ".join(
                f"--{a}" for a in args.required if getattr(args, a) is None
            )
            sys.exit(2)
            if len(required) != 0:
                parser = cls.get_parser()
                parser.print_help(file=sys.stderr)
                p = Path(sys.argv[0]).name
                print(file=sys.stderr)
                print(
                    f"{p}: error: the following arguments are required: " f"{required}",
                    file=sys.stderr,
                )
                sys.exit(2)
    @classmethod
    def check_task_requirements(
            cls,
            dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
            dataset: Union[AbsDataset, IterableESPnetDataset],
            allow_variable_data_keys: bool,
            train: bool,
            inference: bool = False,
@@ -1087,6 +1112,22 @@
            cls.main_worker(args)
    @classmethod
    def run(cls):
        assert hasattr(cls, "finetune_args")
        args = cls.finetune_args
        args.train_shape_file = None
        if args.distributed:
            args.simple_ddp = True
        else:
            args.simple_ddp = False
            args.ngpu = 1
        args.use_pai = False
        args.batch_type = "length"
        args.oss_bucket = None
        args.input_size = None
        cls.main_worker(args)
    @classmethod
    def main_worker(cls, args: argparse.Namespace):
        assert check_argument_types()
@@ -1095,8 +1136,48 @@
        # Setting distributed_option.dist_rank, etc.
        if args.use_pai:
            distributed_option.init_options_pai()
        else:
        elif not args.simple_ddp:
            distributed_option.init_options()
        # Invoking torch.distributed.init_process_group
        if args.use_pai:
            distributed_option.init_torch_distributed_pai(args)
        elif not args.simple_ddp:
            distributed_option.init_torch_distributed(args)
        elif args.distributed and args.simple_ddp:
            distributed_option.init_torch_distributed_pai(args)
            args.ngpu = dist.get_world_size()
            if args.dataset_type == "small":
                if args.batch_size is not None:
                    args.batch_size = args.batch_size * args.ngpu
                if args.batch_bins is not None:
                    args.batch_bins = args.batch_bins * args.ngpu
        # filter samples if wav.scp and text are mismatch
        if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                filter_wav_text(args.data_dir, args.train_set)
                filter_wav_text(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()
        if args.train_shape_file is None and args.dataset_type == "small":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
            if args.simple_ddp:
                dist.barrier()
            args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
            args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]
        if args.train_data_file is None and args.dataset_type == "large":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                generate_data_list(args.data_dir, args.train_set)
                generate_data_list(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()
            args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
            args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")
        # NOTE(kamo): Don't use logging before invoking logging.basicConfig()
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
@@ -1124,11 +1205,9 @@
                format=f"[{os.uname()[1].split('.')[0]}]"
                       f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        # Invoking torch.distributed.init_process_group
        if args.use_pai:
            distributed_option.init_torch_distributed_pai(args)
        else:
            distributed_option.init_torch_distributed(args)
        logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                       distributed_option.dist_rank,
                                                                       distributed_option.local_rank))
        # 1. Set random-seed
        set_all_random_seed(args.seed)
@@ -1202,6 +1281,52 @@
        if args.dry_run:
            pass
        elif args.collect_stats:
            # Perform on collect_stats mode. This mode has two roles
            # - Derive the length and dimension of all input data
            # - Accumulate feats, square values, and the length for whitening
            if args.valid_batch_size is None:
                args.valid_batch_size = args.batch_size
            if len(args.train_shape_file) != 0:
                train_key_file = args.train_shape_file[0]
            else:
                train_key_file = None
            if len(args.valid_shape_file) != 0:
                valid_key_file = args.valid_shape_file[0]
            else:
                valid_key_file = None
            collect_stats(
                model=model,
                train_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
                    key_file=train_key_file,
                    batch_size=args.batch_size,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                valid_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
                    key_file=valid_key_file,
                    batch_size=args.valid_batch_size,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                output_dir=output_dir,
                ngpu=args.ngpu,
                log_interval=args.log_interval,
                write_collected_feats=args.write_collected_feats,
            )
        else:
            logging.info("Training args: {}".format(args))
            # 6. Loads pre-trained model
@@ -1222,10 +1347,14 @@
            # 7. Build iterator factories
            if args.dataset_type == "large":
                from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list,
                                                   args.config, mode="train")
                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list,
                                                   args.config, mode="eval")
                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                               "seg_dict_file") else None,
                                                   mode="train")
                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                               "seg_dict_file") else None,
                                                   mode="eval")
            elif args.dataset_type == "small":
                train_iter_factory = cls.build_iter_factory(
                    args=args,
@@ -1713,6 +1842,7 @@
            collate_fn,
            key_file: str = None,
            batch_size: int = 1,
            fs: dict = None,
            dtype: str = np.float32,
            num_workers: int = 1,
            allow_variable_data_keys: bool = False,
@@ -1730,6 +1860,7 @@
        dataset = IterableESPnetDataset(
            data_path_and_name_and_type,
            float_dtype=dtype,
            fs=fs,
            preprocess=preprocess_fn,
            key_file=key_file,
        )
@@ -1749,70 +1880,13 @@
            **kwargs,
        )
    @classmethod
    def build_streaming_iterator_modelscope(
            cls,
            data_path_and_name_and_type,
            preprocess_fn,
            collate_fn,
            key_file: str = None,
            batch_size: int = 1,
            dtype: str = np.float32,
            num_workers: int = 1,
            allow_variable_data_keys: bool = False,
            ngpu: int = 0,
            inference: bool = False,
            sample_rate: Union[dict, int] = 16000
    ) -> DataLoader:
        """Build DataLoader using iterable dataset"""
        assert check_argument_types()
        # For backward compatibility for pytorch DataLoader
        if collate_fn is not None:
            kwargs = dict(collate_fn=collate_fn)
        else:
            kwargs = {}
        audio_data = data_path_and_name_and_type[0]
        if isinstance(audio_data, bytes):
            dataset = IterableESPnetBytesModelScope(
                data_path_and_name_and_type,
                float_dtype=dtype,
                preprocess=preprocess_fn,
                key_file=key_file,
                sample_rate=sample_rate
            )
        else:
            dataset = IterableESPnetDatasetModelScope(
                data_path_and_name_and_type,
                float_dtype=dtype,
                preprocess=preprocess_fn,
                key_file=key_file,
                sample_rate=sample_rate
            )
        if dataset.apply_utt2category:
            kwargs.update(batch_size=1)
        else:
            kwargs.update(batch_size=batch_size)
        cls.check_task_requirements(dataset,
                                    allow_variable_data_keys,
                                    train=False,
                                    inference=inference)
        return DataLoader(
            dataset=dataset,
            pin_memory=ngpu > 0,
            num_workers=num_workers,
            **kwargs,
        )
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
            cls,
            config_file: Union[Path, str] = None,
            model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
        """Build model from the files.
@@ -1837,6 +1911,8 @@
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
@@ -1850,5 +1926,5 @@
                #   in PyTorch<=1.4
                device = f"cuda:{torch.cuda.current_device()}"
            model.load_state_dict(torch.load(model_file, map_location=device))
        model.to(device)
        return model, args