游雁
2024-03-27 9b4e9cc8a0311e5243d69b73ed073e7ea441982e
funasr/bin/compute_audio_cmvn.py
@@ -18,7 +18,7 @@
    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
    
@@ -28,7 +28,7 @@
def main(**kwargs):
    print(kwargs)
    # set random seed
    tables.print()
    # tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@@ -54,21 +54,15 @@
    dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))
    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_train = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        dataset_conf = kwargs.get("dataset_conf")
        dataset_conf["batch_type"] = "example"
        dataset_conf["batch_size"] = 1
        batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
    dataset_conf = kwargs.get("dataset_conf")
    dataset_conf["batch_type"] = "example"
    dataset_conf["batch_size"] = 1
    dataset_conf["num_workers"] = os.cpu_count() or 32
    batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
    dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                collate_fn=dataset_train.collator,
                                                batch_sampler=batch_sampler_train,
                                                num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
                                                pin_memory=True)
    dataloader_train = torch.utils.data.DataLoader(dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train)
    
    iter_stop = int(kwargs.get("scale", 1.0)*len(dataloader_train))
@@ -79,8 +73,8 @@
        fbank = batch["speech"].numpy()[0, :, :]
        if total_frames == 0:
            mean_stats = fbank
            var_stats = np.square(fbank)
            mean_stats = np.sum(fbank, axis=0)
            var_stats = np.sum(np.square(fbank), axis=0)
        else:
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
@@ -93,6 +87,7 @@
        'total_frames': total_frames
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
    # import pdb;pdb.set_trace()
    with open(cmvn_file, 'w') as fout:
        fout.write(json.dumps(cmvn_info))
        
@@ -110,14 +105,14 @@
        fout.write("</Nnet>" + '\n')
    
    
# Example usage. The `--config-path` / `--config-name` flags and the
# `++key=value` overrides are passed on the command line to `main_hydra`;
# adjust the paths to your own checkout and data:
#
#   python funasr/bin/compute_audio_cmvn.py \
#     --config-path "examples/aishell/paraformer/conf" \
#     --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
#     ++train_data_set_list="data/list/audio_datasets.jsonl" \
#     ++cmvn_file="data/list/cmvn.json" \
#     ++dataset_conf.num_workers=0
if __name__ == "__main__":
    # Script entry point only; importing this module does not run anything.
    main_hydra()