infer_cmd=utils/run.pl

# general configuration
feats_dir="/nfs/wangjiaming.wjm/Funasr_data_test/aishell"   # feature output directory
exp_dir="."
lang=zh
dumpdir=dump/fbank

--use_preprocessor true \
--token_type char \
--token_list $token_list \
--train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
--train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
--train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
--train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
--valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
--valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
--valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
--valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
--data_dir ${feats_dir}/data \
--train_set ${train_set} \
--valid_set ${valid_set} \
--cmvn_file ${feats_dir}/cmvn/cmvn.mvn \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
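
Note that several of these flags are passed more than once. With the `action="append"` definitions shown further below, each repetition appends one entry, and `str2triple_str` turns each `path,name,type` value into a triple. A minimal, self-contained sketch of that behaviour (the `str2triple_str` here is a simplified stand-in for `funasr.utils.types.str2triple_str`, and the paths are only illustrative):

```python
import argparse


def str2triple_str(value):
    # Simplified stand-in for funasr.utils.types.str2triple_str:
    # "path,name,type" -> ("path", "name", "type").
    path, name, type_ = value.split(",", 2)
    return path, name, type_


parser = argparse.ArgumentParser()
parser.add_argument("--train_data_path_and_name_and_type",
                    type=str2triple_str, action="append", default=[])
parser.add_argument("--train_shape_file", type=str, action="append", default=[])

args = parser.parse_args([
    "--train_data_path_and_name_and_type", "dump/fbank/train/feats.scp,speech,kaldi_ark",
    "--train_data_path_and_name_and_type", "dump/fbank/train/text,text,text",
    "--train_shape_file", "asr_stats_fbank_zh_char/train/speech_shape",
    "--train_shape_file", "asr_stats_fbank_zh_char/train/text_shape.char",
])
print(args.train_data_path_and_name_and_type)  # two (path, name, type) triples
print(args.train_shape_file)                   # two shape-file paths
```
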
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump

    help=f"The keyword arguments for dataset",
)
parser.add_argument(
    "--train_data_file",
    type=str,
    default=None,
    help="train_list for large dataset",
)
parser.add_argument(
    "--data_dir",
    type=str,
    default=None,
    help="root path of data",
)
parser.add_argument(
    "--valid_data_file",
    type=str,
    default=None,
    help="valid_list for large dataset",
)
parser.add_argument(
    "--train_set",
    type=str,
    default="train",
    help="train dataset",
)
parser.add_argument(
    "--train_data_path_and_name_and_type",
    type=str2triple_str,
    action="append",
    default=[],
    help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
)
parser.add_argument(
    "--valid_data_path_and_name_and_type",
    type=str2triple_str,
    action="append",
    default=[],
)
parser.add_argument(
    "--train_shape_file",
    type=str,
    action="append",
    default=[],
)
parser.add_argument(
    "--valid_set",
    type=str,
    default="validation",
    help="dev dataset",
)
parser.add_argument(
    "--valid_shape_file",
    type=str,
    action="append",
    default=[],
)

parser.add_argument(
    "--use_preprocessor",
    type=str2bool,

                f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
            else:
                filter_count += 1
    logging.info("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".
                 format(filter_count, len(wav_lines), dataset))
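
Only the tail of `filter_wav_text` appears above. As a rough illustration of the idea (keep a `wav.scp` entry only if its utterance id also appears in `text`, and count the rest), here is a simplified, self-contained stand-in, not the actual FunASR implementation:

```python
import logging

logging.basicConfig(level=logging.INFO)


def filter_wav_text_demo(wav_lines, text_dict, dataset="train"):
    # wav_lines: lines of wav.scp, e.g. "utt1 /path/a.wav"
    # text_dict: utterance id -> transcript, built from the text file
    kept, filter_count = [], 0
    for line in wav_lines:
        sample_name = line.split(maxsplit=1)[0]
        if sample_name in text_dict:
            kept.append(sample_name + " " + text_dict[sample_name])
        else:
            filter_count += 1
    logging.info("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".
                 format(filter_count, len(wav_lines), dataset))
    return kept


print(filter_wav_text_demo(["utt1 a.wav", "utt2 b.wav"], {"utt1": "你好"}))
# logs "1/2 samples in train are filtered ..." and returns ['utt1 你好']
```
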


def wav2num_frame(wav_path, frontend_conf):


def prepare_data(args, distributed_option):
    if args.dataset_type == "small" and args.train_data_path_and_name_and_type is not None:
        return
    if args.dataset_type == "large" and args.train_data_file is not None:
        return
    distributed = distributed_option.distributed
    if not hasattr(args, "train_set"):
        args.train_set = "train"
    if not hasattr(args, "valid_set"):
        args.valid_set = "validation"
    if not distributed or distributed_option.dist_rank == 0:
        filter_wav_text(args.data_dir, args.train_set)
        filter_wav_text(args.data_dir, args.valid_set)

        if args.dataset_type == "small" and args.train_shape_file is None:
            calc_shape(args, args.train_set)
            calc_shape(args, args.valid_set)

        if args.dataset_type == "large" and args.train_data_file is None:
            generate_data_list(args.data_dir, args.train_set)
            generate_data_list(args.data_dir, args.valid_set)

    if args.dataset_type == "small":
        args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
        args.valid_shape_file = [os.path.join(args.data_dir, args.valid_set, "speech_shape")]
        data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
        data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
        args.train_data_path_and_name_and_type = [
            ["{}/{}/wav.scp".format(args.data_dir, args.train_set), data_names[0], data_types[0]],
            ["{}/{}/text".format(args.data_dir, args.train_set), data_names[1], data_types[1]]
        ]
        args.valid_data_path_and_name_and_type = [
            ["{}/{}/wav.scp".format(args.data_dir, args.valid_set), data_names[0], data_types[0]],
            ["{}/{}/text".format(args.data_dir, args.valid_set), data_names[1], data_types[1]]
        ]
    else:
        args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
        args.valid_data_file = os.path.join(args.data_dir, args.valid_set, "data.list")
    if distributed:
        dist.barrier()
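
Taken together, `prepare_data` assumes a Kaldi-style layout under `--data_dir`, with `wav.scp` and `text` per split. A small, self-contained sketch of the paths the `"small"` branch ends up wiring (directory names and the default speech/text names and types are taken from the code above; the `"data"` value is only illustrative):

```python
import os
from pprint import pprint


def small_dataset_layout(data_dir, train_set="train", valid_set="validation"):
    # Mirrors the "small" branch of prepare_data: per split, one speech_shape
    # file plus (path, name, type) entries pointing at wav.scp and text.
    names, types = ["speech", "text"], ["sound", "text"]
    layout = {}
    for split in (train_set, valid_set):
        layout[split] = {
            "shape_file": os.path.join(data_dir, split, "speech_shape"),
            "data": [
                [os.path.join(data_dir, split, "wav.scp"), names[0], types[0]],
                [os.path.join(data_dir, split, "text"), names[1], types[1]],
            ],
        }
    return layout


pprint(small_dataset_layout("data"))  # e.g. data/train/wav.scp, data/validation/text, ...
```
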