From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/models/transformer/utils/nets_utils.py | 66 ++++++++-------------------------
 1 file changed, 16 insertions(+), 50 deletions(-)
diff --git a/funasr/models/transformer/utils/nets_utils.py b/funasr/models/transformer/utils/nets_utils.py
index 0beb083..29d23ee 100644
--- a/funasr/models/transformer/utils/nets_utils.py
+++ b/funasr/models/transformer/utils/nets_utils.py
@@ -25,9 +25,7 @@
elif isinstance(m, torch.Tensor):
device = m.device
else:
- raise TypeError(
- "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
- )
+        raise TypeError("Expected torch.nn.Module or torch.tensor, " f"but got: {type(m)}")
return x.to(device)
@@ -215,9 +213,7 @@
if length_dim < 0:
length_dim = xs.dim() + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
- ind = tuple(
- slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
- )
+ ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim()))
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
@@ -342,29 +338,6 @@
return ret
-def th_accuracy(pad_outputs, pad_targets, ignore_label):
- """Calculate accuracy.
-
- Args:
- pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
- pad_targets (LongTensor): Target label tensors (B, Lmax).
- ignore_label (int): Ignore label id.
-
- Returns:
- float: Accuracy value (0.0 - 1.0).
-
- """
- pad_pred = pad_outputs.view(
- pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
- ).argmax(2)
- mask = pad_targets != ignore_label
- numerator = torch.sum(
- pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
- )
- denominator = torch.sum(mask)
- return float(numerator) / float(denominator)
-
-
def to_torch_tensor(x):
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
@@ -455,9 +428,9 @@
return subsample
elif (
- (mode == "asr" and arch in ("rnn", "rnn-t"))
- or (mode == "mt" and arch == "rnn")
- or (mode == "st" and arch == "rnn")
+ (mode == "asr" and arch in ("rnn", "rnn-t"))
+ or (mode == "mt" and arch == "rnn")
+ or (mode == "st" and arch == "rnn")
):
subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
@@ -473,14 +446,10 @@
return subsample
elif mode == "asr" and arch == "rnn_mix":
- subsample = np.ones(
- train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32
- )
+ subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32)
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
ss = train_args.subsample.split("_")
- for j in range(
- min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
- ):
+ for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
subsample[j] = int(ss[j])
else:
logging.warning(
@@ -494,9 +463,7 @@
subsample_list = []
for idx in range(train_args.num_encs):
subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
- if train_args.etype[idx].endswith("p") and not train_args.etype[
- idx
- ].startswith("vgg"):
+ if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
ss = train_args.subsample[idx].split("_")
for j in range(min(train_args.elayers[idx] + 1, len(ss))):
subsample[j] = int(ss[j])
@@ -514,9 +481,7 @@
raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
-def rename_state_dict(
- old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
-):
+def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
"""Replace keys of old prefix with new prefix in state dict."""
# need this list not to break the dict iterator
old_keys = [k for k in state_dict if k.startswith(old_prefix)]
@@ -526,6 +491,7 @@
v = state_dict.pop(k)
new_k = k.replace(old_prefix, new_prefix)
state_dict[new_k] = v
+
class Swish(torch.nn.Module):
"""Swish activation definition.
@@ -561,6 +527,7 @@
"""Forward computation."""
return self.swish(x)
+
def get_activation(act):
"""Return activation function."""
@@ -573,6 +540,7 @@
}
return activation_funcs[act]()
+
class TooShortUttError(Exception):
"""Raised when the utt is too short for subsampling.
@@ -634,9 +602,7 @@
elif sub_factor == 6:
return 5, 3, (((input_size - 1) // 2 - 2) // 3)
else:
- raise ValueError(
- "subsampling_factor parameter should be set to either 2, 4 or 6."
- )
+ raise ValueError("subsampling_factor parameter should be set to either 2, 4 or 6.")
def make_chunk_mask(
@@ -671,6 +637,7 @@
mask[i, start:end] = True
return ~mask
+
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Create source mask for given lengths.
@@ -756,6 +723,7 @@
return decoder_in, target, t_len, u_len
+
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
"""Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
if t.size(dim) == pad_len:
@@ -763,6 +731,4 @@
else:
pad_size = list(t.shape)
pad_size[dim] = pad_len - t.size(dim)
- return torch.cat(
- [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
- )
+ return torch.cat([t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim)
--
Gitblit v1.9.1