From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/bin/compute_audio_cmvn.py | 84 ++++++++++++++++++++++++++----------------
1 files changed, 52 insertions(+), 32 deletions(-)
diff --git a/funasr/bin/compute_audio_cmvn.py b/funasr/bin/compute_audio_cmvn.py
index cd64329..79c94c6 100644
--- a/funasr/bin/compute_audio_cmvn.py
+++ b/funasr/bin/compute_audio_cmvn.py
@@ -7,20 +7,21 @@
from omegaconf import DictConfig, OmegaConf
from funasr.register import tables
-from funasr.download.download_from_hub import download_model
+from funasr.download.download_model_from_hub import download_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
if kwargs.get("debug", False):
- import pdb; pdb.set_trace()
+ import pdb
+
+ pdb.set_trace()
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
-
main(**kwargs)
@@ -33,12 +34,9 @@
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
-
-
-
tokenizer = kwargs.get("tokenizer", None)
-
+
# build frontend if frontend is none None
frontend = kwargs.get("frontend", None)
if frontend is not None:
@@ -47,11 +45,15 @@
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
-
-
# dataset
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
- dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))
+ dataset_train = dataset_class(
+ kwargs.get("train_data_set_list"),
+ frontend=frontend,
+ tokenizer=None,
+ is_training=False,
+ **kwargs.get("dataset_conf"),
+ )
# dataloader
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
@@ -62,13 +64,18 @@
dataset_conf["num_workers"] = os.cpu_count() or 32
batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
- dataloader_train = torch.utils.data.DataLoader(dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train)
-
- iter_stop = int(kwargs.get("scale", 1.0)*len(dataloader_train))
+ dataloader_train = torch.utils.data.DataLoader(
+ dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train
+ )
total_frames = 0
for batch_idx, batch in enumerate(dataloader_train):
- if batch_idx >= iter_stop:
+ iter_stop = int(kwargs.get("scale", -1.0) * len(dataloader_train))
+ log_step = iter_stop // 100
+ if batch_idx % log_step == 0:
+ logging.info(f"prcessed: {batch_idx}/{iter_stop}")
+ if batch_idx >= iter_stop and iter_stop > 0.0:
+ logging.info(f"prcessed: {iter_stop}/{iter_stop}")
break
fbank = batch["speech"].numpy()[0, :, :]
@@ -79,33 +86,46 @@
mean_stats += np.sum(fbank, axis=0)
var_stats += np.sum(np.square(fbank), axis=0)
total_frames += fbank.shape[0]
-
-
+
cmvn_info = {
- 'mean_stats': list(mean_stats.tolist()),
- 'var_stats': list(var_stats.tolist()),
- 'total_frames': total_frames
+ "mean_stats": list(mean_stats.tolist()),
+ "var_stats": list(var_stats.tolist()),
+ "total_frames": total_frames,
}
cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
# import pdb;pdb.set_trace()
- with open(cmvn_file, 'w') as fout:
+ with open(cmvn_file, "w") as fout:
fout.write(json.dumps(cmvn_info))
-
+
mean = -1.0 * mean_stats / total_frames
var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
dims = mean.shape[0]
am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
- with open(am_mvn, 'w') as fout:
- fout.write("<Nnet>" + "\n" + "<Splice> " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + "<AddShift> " + str(dims) + " " + str(dims) + "\n")
- mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
- fout.write("<LearnRateCoef> 0 " + mean_str + '\n')
- fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
- var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
- fout.write("<LearnRateCoef> 0 " + var_str + '\n')
- fout.write("</Nnet>" + '\n')
-
-
-
+ with open(am_mvn, "w") as fout:
+ fout.write(
+ "<Nnet>"
+ + "\n"
+ + "<Splice> "
+ + str(dims)
+ + " "
+ + str(dims)
+ + "\n"
+ + "[ 0 ]"
+ + "\n"
+ + "<AddShift> "
+ + str(dims)
+ + " "
+ + str(dims)
+ + "\n"
+ )
+ mean_str = str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]")
+ fout.write("<LearnRateCoef> 0 " + mean_str + "\n")
+ fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n")
+ var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]")
+ fout.write("<LearnRateCoef> 0 " + var_str + "\n")
+ fout.write("</Nnet>" + "\n")
+
+
"""
python funasr/bin/compute_audio_cmvn.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
--
Gitblit v1.9.1