From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/scama/utils.py |  124 ++++++++++++++++++++---------------------
 1 files changed, 61 insertions(+), 63 deletions(-)

diff --git a/funasr/models/scama/utils.py b/funasr/models/scama/utils.py
index 4bb9d4f..c3f7bc3 100644
--- a/funasr/models/scama/utils.py
+++ b/funasr/models/scama/utils.py
@@ -1,91 +1,89 @@
 import os
-import torch
-from torch.nn import functional as F
 import yaml
+import torch
 import numpy as np
+from torch.nn import functional as F
+
 
 def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
-	if maxlen is None:
-		maxlen = lengths.max()
-	row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
-	matrix = torch.unsqueeze(lengths, dim=-1)
-	mask = row_vector < matrix
-	mask = mask.detach()
+    if maxlen is None:
+        maxlen = lengths.max()
+    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
+    matrix = torch.unsqueeze(lengths, dim=-1)
+    mask = row_vector < matrix
+    mask = mask.detach()
 
-	return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
+    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
+
 
 def apply_cmvn(inputs, mvn):
-	device = inputs.device
-	dtype = inputs.dtype
-	frame, dim = inputs.shape
-	meams = np.tile(mvn[0:1, :dim], (frame, 1))
-	vars = np.tile(mvn[1:2, :dim], (frame, 1))
-	inputs -= torch.from_numpy(meams).type(dtype).to(device)
-	inputs *= torch.from_numpy(vars).type(dtype).to(device)
+    device = inputs.device
+    dtype = inputs.dtype
+    frame, dim = inputs.shape
+    meams = np.tile(mvn[0:1, :dim], (frame, 1))
+    vars = np.tile(mvn[1:2, :dim], (frame, 1))
+    inputs -= torch.from_numpy(meams).type(dtype).to(device)
+    inputs *= torch.from_numpy(vars).type(dtype).to(device)
 
-	return inputs.type(torch.float32)
+    return inputs.type(torch.float32)
 
 
+def drop_and_add(
+    inputs: torch.Tensor,
+    outputs: torch.Tensor,
+    training: bool,
+    dropout_rate: float = 0.1,
+    stoch_layer_coeff: float = 1.0,
+):
 
+    outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
+    outputs *= stoch_layer_coeff
 
-def drop_and_add(inputs: torch.Tensor,
-                 outputs: torch.Tensor,
-                 training: bool,
-                 dropout_rate: float = 0.1,
-                 stoch_layer_coeff: float = 1.0):
+    input_dim = inputs.size(-1)
+    output_dim = outputs.size(-1)
 
-
-
-	outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
-	outputs *= stoch_layer_coeff
-
-	input_dim = inputs.size(-1)
-	output_dim = outputs.size(-1)
-
-	if input_dim == output_dim:
-		outputs += inputs
-	return outputs
+    if input_dim == output_dim:
+        outputs += inputs
+    return outputs
 
 
 def proc_tf_vocab(vocab_path):
-	with open(vocab_path, encoding="utf-8") as f:
-		token_list = [line.rstrip() for line in f]
-		if '<unk>' not in token_list:
-			token_list.append('<unk>')
-	return token_list
+    with open(vocab_path, encoding="utf-8") as f:
+        token_list = [line.rstrip() for line in f]
+        if "<unk>" not in token_list:
+            token_list.append("<unk>")
+    return token_list
 
 
 def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
-	token_list = proc_tf_vocab(vocab_path)
-	with open(config_path, encoding="utf-8") as f:
-		config = yaml.safe_load(f)
-	
-	config['token_list'] = token_list
-	
-	if not os.path.exists(output_dir):
-		os.makedirs(output_dir)
-	
-	with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
-		yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
+    token_list = proc_tf_vocab(vocab_path)
+    with open(config_path, encoding="utf-8") as f:
+        config = yaml.safe_load(f)
+
+    config["token_list"] = token_list
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
+        yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
 
 
 class NoAliasSafeDumper(yaml.SafeDumper):
-	# Disable anchor/alias in yaml because looks ugly
-	def ignore_aliases(self, data):
-		return True
+    # Disable anchor/alias in yaml because looks ugly
+    def ignore_aliases(self, data):
+        return True
 
 
 def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
-	"""Safe-dump in yaml with no anchor/alias"""
-	return yaml.dump(
-		data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
-	)
+    """Safe-dump in yaml with no anchor/alias"""
+    return yaml.dump(data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs)
 
 
-if __name__ == '__main__':
-	import sys
-	
-	config_path = sys.argv[1]
-	vocab_path = sys.argv[2]
-	output_dir = sys.argv[3]
-	gen_config_for_tfmodel(config_path, vocab_path, output_dir)
\ No newline at end of file
+if __name__ == "__main__":
+    import sys
+
+    config_path = sys.argv[1]
+    vocab_path = sys.argv[2]
+    output_dir = sys.argv[3]
+    gen_config_for_tfmodel(config_path, vocab_path, output_dir)

--
Gitblit v1.9.1