| | |
| | | from omegaconf import DictConfig, OmegaConf |
| | | |
| | | from funasr.register import tables |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.download.download_model_from_hub import download_model |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | |
| | | |
| | |
| | | frontend=frontend, |
| | | tokenizer=None, |
| | | is_training=False, |
| | | **kwargs.get("dataset_conf") |
| | | **kwargs.get("dataset_conf"), |
| | | ) |
| | | |
| | | # dataloader |
| | |
| | | dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train |
| | | ) |
| | | |
| | | iter_stop = int(kwargs.get("scale", 1.0) * len(dataloader_train)) |
| | | |
| | | total_frames = 0 |
| | | for batch_idx, batch in enumerate(dataloader_train): |
| | | if batch_idx >= iter_stop: |
| | | iter_stop = int(kwargs.get("scale", -1.0) * len(dataloader_train)) |
| | | log_step = iter_stop // 100 |
| | | if batch_idx % log_step == 0: |
| | | logging.info(f"prcessed: {batch_idx}/{iter_stop}") |
| | | if batch_idx >= iter_stop and iter_stop > 0.0: |
| | | logging.info(f"prcessed: {iter_stop}/{iter_stop}") |
| | | break |
| | | |
| | | fbank = batch["speech"].numpy()[0, :, :] |
| | |
| | | total_frames += fbank.shape[0] |
| | | |
| | | cmvn_info = { |
| | | "mean_stats": list(mean_stats.tolist()), |
| | | "var_stats": list(var_stats.tolist()), |
| | | "mean_stats": mean_stats.tolist(), |
| | | "var_stats": var_stats.tolist(), |
| | | "total_frames": total_frames, |
| | | } |
| | | cmvn_file = kwargs.get("cmvn_file", "cmvn.json") |
| | |
| | | + str(dims) |
| | | + "\n" |
| | | ) |
| | | mean_str = str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]") |
| | | fout.write("<LearnRateCoef> 0 " + mean_str + "\n") |
| | | fout.write("<LearnRateCoef> 0 [ " + " ".join([str(item) for item in mean]) + " ]\n") |
| | | fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n") |
| | | var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]") |
| | | fout.write("<LearnRateCoef> 0 " + var_str + "\n") |
| | | fout.write("<LearnRateCoef> 0 [ " + " ".join([str(item) for item in var]) + " ]\n") |
| | | fout.write("</Nnet>" + "\n") |
| | | |
| | | |