From 1233c0d3ff9cf7fd6131862e7d0b208d3981f6da Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 15 一月 2024 20:34:47 +0800
Subject: [PATCH] code update
---
funasr/models/scama/utils.py | 107 +++++++++++++++++++++++++++--------------------------
1 files changed, 54 insertions(+), 53 deletions(-)
diff --git a/funasr/models/scama/utils.py b/funasr/models/scama/utils.py
index 4bb9d4f..8832596 100644
--- a/funasr/models/scama/utils.py
+++ b/funasr/models/scama/utils.py
@@ -1,29 +1,30 @@
import os
-import torch
-from torch.nn import functional as F
import yaml
+import torch
import numpy as np
+from torch.nn import functional as F
+
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
- if maxlen is None:
- maxlen = lengths.max()
- row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
- matrix = torch.unsqueeze(lengths, dim=-1)
- mask = row_vector < matrix
- mask = mask.detach()
+ if maxlen is None:
+ maxlen = lengths.max()
+ row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
+ matrix = torch.unsqueeze(lengths, dim=-1)
+ mask = row_vector < matrix
+ mask = mask.detach()
- return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
+ return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
def apply_cmvn(inputs, mvn):
- device = inputs.device
- dtype = inputs.dtype
- frame, dim = inputs.shape
- meams = np.tile(mvn[0:1, :dim], (frame, 1))
- vars = np.tile(mvn[1:2, :dim], (frame, 1))
- inputs -= torch.from_numpy(meams).type(dtype).to(device)
- inputs *= torch.from_numpy(vars).type(dtype).to(device)
+ device = inputs.device
+ dtype = inputs.dtype
+ frame, dim = inputs.shape
+ meams = np.tile(mvn[0:1, :dim], (frame, 1))
+ vars = np.tile(mvn[1:2, :dim], (frame, 1))
+ inputs -= torch.from_numpy(meams).type(dtype).to(device)
+ inputs *= torch.from_numpy(vars).type(dtype).to(device)
- return inputs.type(torch.float32)
+ return inputs.type(torch.float32)
@@ -36,56 +37,56 @@
- outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
- outputs *= stoch_layer_coeff
+ outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
+ outputs *= stoch_layer_coeff
- input_dim = inputs.size(-1)
- output_dim = outputs.size(-1)
+ input_dim = inputs.size(-1)
+ output_dim = outputs.size(-1)
- if input_dim == output_dim:
- outputs += inputs
- return outputs
+ if input_dim == output_dim:
+ outputs += inputs
+ return outputs
def proc_tf_vocab(vocab_path):
- with open(vocab_path, encoding="utf-8") as f:
- token_list = [line.rstrip() for line in f]
- if '<unk>' not in token_list:
- token_list.append('<unk>')
- return token_list
+ with open(vocab_path, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+ if '<unk>' not in token_list:
+ token_list.append('<unk>')
+ return token_list
def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
- token_list = proc_tf_vocab(vocab_path)
- with open(config_path, encoding="utf-8") as f:
- config = yaml.safe_load(f)
-
- config['token_list'] = token_list
-
- if not os.path.exists(output_dir):
- os.makedirs(output_dir)
-
- with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
- yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
+ token_list = proc_tf_vocab(vocab_path)
+ with open(config_path, encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ config['token_list'] = token_list
+
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
+ yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
class NoAliasSafeDumper(yaml.SafeDumper):
- # Disable anchor/alias in yaml because looks ugly
- def ignore_aliases(self, data):
- return True
+ # Disable anchor/alias in yaml because looks ugly
+ def ignore_aliases(self, data):
+ return True
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
- """Safe-dump in yaml with no anchor/alias"""
- return yaml.dump(
- data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
- )
+ """Safe-dump in yaml with no anchor/alias"""
+ return yaml.dump(
+ data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
+ )
if __name__ == '__main__':
- import sys
-
- config_path = sys.argv[1]
- vocab_path = sys.argv[2]
- output_dir = sys.argv[3]
- gen_config_for_tfmodel(config_path, vocab_path, output_dir)
\ No newline at end of file
+ import sys
+
+ config_path = sys.argv[1]
+ vocab_path = sys.argv[2]
+ output_dir = sys.argv[3]
+ gen_config_for_tfmodel(config_path, vocab_path, output_dir)
\ No newline at end of file
--
Gitblit v1.9.1