| | |
| | | from omegaconf import DictConfig, OmegaConf |
| | | |
| | | from funasr.register import tables |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.download.download_model_from_hub import download_model |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | |
| | | |
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: resolve the model config, then compute CMVN statistics.

    If ``debug`` is truthy, drops into ``pdb`` before doing anything else.
    When ``model_conf`` is absent, the configured hub (``hub``, default
    ``"ms"``) is logged and ``download_model`` is called; its return value
    replaces ``kwargs``. Finally delegates to :func:`main`.

    Raises:
        ValueError: if ``model`` is not present in the config.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "model" not in kwargs:
        raise ValueError("config must specify `model`")

    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)
| | | |
| | |
def main(**kwargs):
    """Compute global CMVN statistics for a training set and write them to disk.

    The dataset is read one example per batch (``batch_type="example"``,
    ``batch_size=1``). For every batch, the per-dimension sum and sum of
    squares of ``batch["speech"][0]`` are accumulated in float64. Two files
    are then written:

    * ``cmvn_file`` (default ``"cmvn.json"``): JSON with ``mean_stats``,
      ``var_stats`` and ``total_frames``;
    * ``am.mvn`` in the same directory as ``cmvn_file``: an ``<Nnet>`` text
      file whose ``<AddShift>`` section holds the negated mean and whose
      ``<Rescale>`` section holds the inverse standard deviation.

    Keys read from ``kwargs``: ``seed``, ``cudnn_enabled``, ``cudnn_benchmark``,
    ``cudnn_deterministic``, ``frontend`` (+ ``frontend_conf``), ``dataset``,
    ``train_data_set_list``, ``dataset_conf``, ``scale`` and ``cmvn_file``.
    If ``scale > 0``, only the first ``int(scale * len(dataloader))`` batches
    are used; otherwise every batch is used.

    Side effects: ``kwargs["dataset_conf"]`` is modified in place (batch
    settings), and ``kwargs["frontend"]`` / ``kwargs["input_size"]`` are set
    when a frontend is configured.

    Raises:
        ValueError: if the dataloader yields no frames.
    """
    # Reproducibility and cuDNN settings.
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    # Build the frontend if one is configured.
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        if isinstance(frontend, str):
            # NOTE(review): a frontend given by name is instantiated from the
            # registry, mirroring the dataset / batch-sampler lookups below;
            # confirm the `frontend_classes` table and `frontend_conf` key.
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs.get("frontend_conf", {}))
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # Dataset: built from the unmodified dataset_conf.
    dataset_conf = kwargs.get("dataset_conf")
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_train = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=None,
        is_training=False,
        **dataset_conf,
    )

    # Dataloader: one example per batch, so indexing [0] below covers the
    # whole batch.
    batch_sampler = dataset_conf.get("batch_sampler", "BatchSampler")
    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
    dataset_conf["batch_type"] = "example"
    dataset_conf["batch_size"] = 1
    dataset_conf["num_workers"] = os.cpu_count() or 32
    # The sampler factory's result is unpacked as DataLoader keyword arguments.
    batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train
    )

    # Non-positive scale (the default) yields iter_stop <= 0, i.e. "use all batches".
    iter_stop = int(kwargs.get("scale", -1.0) * len(dataloader_train))
    mean_stats, var_stats, total_frames = _accumulate_cmvn_stats(dataloader_train, iter_stop)

    cmvn_info = {
        "mean_stats": mean_stats.tolist(),
        "var_stats": var_stats.tolist(),
        "total_frames": total_frames,
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
    with open(cmvn_file, "w") as fout:
        fout.write(json.dumps(cmvn_info))

    mean = -1.0 * mean_stats / total_frames
    # Despite the name, `var` is the inverse standard deviation:
    # 1 / sqrt(E[x^2] - E[x]^2). (mean * mean == E[x]^2 since only the sign is flipped.)
    var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)

    # os.path.join avoids writing to "/am.mvn" when cmvn_file has no directory part.
    am_mvn = os.path.join(os.path.dirname(cmvn_file), "am.mvn")
    _write_am_mvn(am_mvn, mean, var)


def _accumulate_cmvn_stats(dataloader, iter_stop):
    """Return (sum, sum of squares, frame count) of ``batch["speech"][0]`` over ``dataloader``.

    Stops before batch index ``iter_stop`` when ``iter_stop > 0``; otherwise
    consumes every batch. Sums are kept in float64 to limit rounding error
    over many frames.

    Raises:
        ValueError: if no frames were read.
    """
    total = iter_stop if iter_stop > 0 else len(dataloader)
    # Log roughly every 1% of batches; max(..., 1) avoids modulo-by-zero
    # when there are fewer than 100 batches.
    log_step = max(total // 100, 1)

    mean_stats = None
    var_stats = None
    total_frames = 0
    for batch_idx, batch in enumerate(dataloader):
        if 0 < iter_stop <= batch_idx:
            logging.info(f"processed: {iter_stop}/{iter_stop}")
            break
        if batch_idx % log_step == 0:
            logging.info(f"processed: {batch_idx}/{total}")

        fbank = batch["speech"].numpy()[0, :, :].astype(np.float64)
        frame_sum = np.sum(fbank, axis=0)
        frame_sq_sum = np.sum(np.square(fbank), axis=0)
        if mean_stats is None:
            mean_stats, var_stats = frame_sum, frame_sq_sum
        else:
            mean_stats += frame_sum
            var_stats += frame_sq_sum
        total_frames += fbank.shape[0]

    if mean_stats is None or total_frames == 0:
        raise ValueError("no speech frames were read; cannot compute CMVN statistics")
    return mean_stats, var_stats, total_frames


def _write_am_mvn(am_mvn, mean, var):
    """Write the shift vector ``mean`` and scale vector ``var`` to ``am_mvn`` in ``<Nnet>`` text form."""
    dims = mean.shape[0]
    with open(am_mvn, "w") as fout:
        fout.write(
            "<Nnet>\n"
            f"<Splice> {dims} {dims}\n"
            "[ 0 ]\n"
            f"<AddShift> {dims} {dims}\n"
        )
        fout.write("<LearnRateCoef> 0 [ " + " ".join(str(item) for item in mean) + " ]\n")
        fout.write(f"<Rescale> {dims} {dims}\n")
        fout.write("<LearnRateCoef> 0 [ " + " ".join(str(item) for item in var) + " ]\n")
        fout.write("</Nnet>\n")
| | | |
| | | |
| | | """ |
| | | python funasr/bin/compute_audio_cmvn.py \ |
| | | --config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \ |