From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/scama/utils.py |   39 ++++++++++++++++++---------------------
 1 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/funasr/models/scama/utils.py b/funasr/models/scama/utils.py
index 8832596..c3f7bc3 100644
--- a/funasr/models/scama/utils.py
+++ b/funasr/models/scama/utils.py
@@ -15,6 +15,7 @@
 
     return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
 
+
 def apply_cmvn(inputs, mvn):
     device = inputs.device
     dtype = inputs.dtype
@@ -27,15 +28,13 @@
     return inputs.type(torch.float32)
 
 
-
-
-def drop_and_add(inputs: torch.Tensor,
-                 outputs: torch.Tensor,
-                 training: bool,
-                 dropout_rate: float = 0.1,
-                 stoch_layer_coeff: float = 1.0):
-
-
+def drop_and_add(
+    inputs: torch.Tensor,
+    outputs: torch.Tensor,
+    training: bool,
+    dropout_rate: float = 0.1,
+    stoch_layer_coeff: float = 1.0,
+):
 
     outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
     outputs *= stoch_layer_coeff
@@ -51,8 +50,8 @@
 def proc_tf_vocab(vocab_path):
     with open(vocab_path, encoding="utf-8") as f:
         token_list = [line.rstrip() for line in f]
-        if '<unk>' not in token_list:
-            token_list.append('<unk>')
+        if "<unk>" not in token_list:
+            token_list.append("<unk>")
     return token_list
 
 
@@ -60,12 +59,12 @@
     token_list = proc_tf_vocab(vocab_path)
     with open(config_path, encoding="utf-8") as f:
         config = yaml.safe_load(f)
-    
-    config['token_list'] = token_list
-    
+
+    config["token_list"] = token_list
+
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-    
+
     with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
         yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
 
@@ -78,15 +77,13 @@
 
 def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
     """Safe-dump in yaml with no anchor/alias"""
-    return yaml.dump(
-        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
-    )
+    return yaml.dump(data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import sys
-    
+
     config_path = sys.argv[1]
     vocab_path = sys.argv[2]
     output_dir = sys.argv[3]
-    gen_config_for_tfmodel(config_path, vocab_path, output_dir)
\ No newline at end of file
+    gen_config_for_tfmodel(config_path, vocab_path, output_dir)

--
Gitblit v1.9.1