From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/bin/compute_audio_cmvn.py |   84 ++++++++++++++++++++++++++----------------
 1 files changed, 52 insertions(+), 32 deletions(-)

diff --git a/funasr/bin/compute_audio_cmvn.py b/funasr/bin/compute_audio_cmvn.py
index cd64329..79c94c6 100644
--- a/funasr/bin/compute_audio_cmvn.py
+++ b/funasr/bin/compute_audio_cmvn.py
@@ -7,20 +7,21 @@
 from omegaconf import DictConfig, OmegaConf
 
 from funasr.register import tables
-from funasr.download.download_from_hub import download_model
+from funasr.download.download_model_from_hub import download_model
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
     if kwargs.get("debug", False):
-        import pdb; pdb.set_trace()
+        import pdb
+
+        pdb.set_trace()
 
     assert "model" in kwargs
     if "model_conf" not in kwargs:
         logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
         kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
-    
 
     main(**kwargs)
 
@@ -33,12 +34,9 @@
     torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
     torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
     torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
-    
-
-
 
     tokenizer = kwargs.get("tokenizer", None)
-    
+
     # build frontend if frontend is none None
     frontend = kwargs.get("frontend", None)
     if frontend is not None:
@@ -47,11 +45,15 @@
         kwargs["frontend"] = frontend
         kwargs["input_size"] = frontend.output_size()
 
-
-
     # dataset
     dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
-    dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))
+    dataset_train = dataset_class(
+        kwargs.get("train_data_set_list"),
+        frontend=frontend,
+        tokenizer=None,
+        is_training=False,
+        **kwargs.get("dataset_conf"),
+    )
 
     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
@@ -62,13 +64,18 @@
     dataset_conf["num_workers"] = os.cpu_count() or 32
     batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
 
-    dataloader_train = torch.utils.data.DataLoader(dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train)
-    
-    iter_stop = int(kwargs.get("scale", 1.0)*len(dataloader_train))
+    dataloader_train = torch.utils.data.DataLoader(
+        dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train
+    )
 
     total_frames = 0
     for batch_idx, batch in enumerate(dataloader_train):
-        if batch_idx >= iter_stop:
+        iter_stop = int(kwargs.get("scale", -1.0) * len(dataloader_train))
+        log_step = iter_stop // 100
+        if batch_idx % log_step == 0:
+            logging.info(f"prcessed: {batch_idx}/{iter_stop}")
+        if batch_idx >= iter_stop and iter_stop > 0.0:
+            logging.info(f"prcessed: {iter_stop}/{iter_stop}")
             break
 
         fbank = batch["speech"].numpy()[0, :, :]
@@ -79,33 +86,46 @@
             mean_stats += np.sum(fbank, axis=0)
             var_stats += np.sum(np.square(fbank), axis=0)
         total_frames += fbank.shape[0]
-        
-        
+
     cmvn_info = {
-        'mean_stats': list(mean_stats.tolist()),
-        'var_stats': list(var_stats.tolist()),
-        'total_frames': total_frames
+        "mean_stats": list(mean_stats.tolist()),
+        "var_stats": list(var_stats.tolist()),
+        "total_frames": total_frames,
     }
     cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
     # import pdb;pdb.set_trace()
-    with open(cmvn_file, 'w') as fout:
+    with open(cmvn_file, "w") as fout:
         fout.write(json.dumps(cmvn_info))
-        
+
     mean = -1.0 * mean_stats / total_frames
     var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
     dims = mean.shape[0]
     am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
-    with open(am_mvn, 'w') as fout:
-        fout.write("<Nnet>" + "\n" + "<Splice> " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + "<AddShift> " + str(dims) + " " + str(dims) + "\n")
-        mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
-        fout.write("<LearnRateCoef> 0 " + mean_str + '\n')
-        fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
-        var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
-        fout.write("<LearnRateCoef> 0 " + var_str + '\n')
-        fout.write("</Nnet>" + '\n')
-    
-    
-    
+    with open(am_mvn, "w") as fout:
+        fout.write(
+            "<Nnet>"
+            + "\n"
+            + "<Splice> "
+            + str(dims)
+            + " "
+            + str(dims)
+            + "\n"
+            + "[ 0 ]"
+            + "\n"
+            + "<AddShift> "
+            + str(dims)
+            + " "
+            + str(dims)
+            + "\n"
+        )
+        mean_str = str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]")
+        fout.write("<LearnRateCoef> 0 " + mean_str + "\n")
+        fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n")
+        var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]")
+        fout.write("<LearnRateCoef> 0 " + var_str + "\n")
+        fout.write("</Nnet>" + "\n")
+
+
 """
 python funasr/bin/compute_audio_cmvn.py \
 --config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \

--
Gitblit v1.9.1