From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/transformer/utils/nets_utils.py |   66 ++++++++-------------------------
 1 file changed, 16 insertions(+), 50 deletions(-)

diff --git a/funasr/models/transformer/utils/nets_utils.py b/funasr/models/transformer/utils/nets_utils.py
index 0beb083..29d23ee 100644
--- a/funasr/models/transformer/utils/nets_utils.py
+++ b/funasr/models/transformer/utils/nets_utils.py
@@ -25,9 +25,7 @@
     elif isinstance(m, torch.Tensor):
         device = m.device
     else:
-        raise TypeError(
-            "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
-        )
+        raise TypeError("Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}")
     return x.to(device)
 
 
@@ -215,9 +213,7 @@
         if length_dim < 0:
             length_dim = xs.dim() + length_dim
         # ind = (:, None, ..., None, :, , None, ..., None)
-        ind = tuple(
-            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
-        )
+        ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim()))
         mask = mask[ind].expand_as(xs).to(xs.device)
     return mask
 
@@ -342,29 +338,6 @@
     return ret
 
 
-def th_accuracy(pad_outputs, pad_targets, ignore_label):
-    """Calculate accuracy.
-
-    Args:
-        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
-        pad_targets (LongTensor): Target label tensors (B, Lmax).
-        ignore_label (int): Ignore label id.
-
-    Returns:
-        float: Accuracy value (0.0 - 1.0).
-
-    """
-    pad_pred = pad_outputs.view(
-        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
-    ).argmax(2)
-    mask = pad_targets != ignore_label
-    numerator = torch.sum(
-        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
-    )
-    denominator = torch.sum(mask)
-    return float(numerator) / float(denominator)
-
-
 def to_torch_tensor(x):
     """Change to torch.Tensor or ComplexTensor from numpy.ndarray.
 
@@ -455,9 +428,9 @@
         return subsample
 
     elif (
-            (mode == "asr" and arch in ("rnn", "rnn-t"))
-            or (mode == "mt" and arch == "rnn")
-            or (mode == "st" and arch == "rnn")
+        (mode == "asr" and arch in ("rnn", "rnn-t"))
+        or (mode == "mt" and arch == "rnn")
+        or (mode == "st" and arch == "rnn")
     ):
         subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
         if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
@@ -473,14 +446,10 @@
         return subsample
 
     elif mode == "asr" and arch == "rnn_mix":
-        subsample = np.ones(
-            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32
-        )
+        subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32)
         if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
             ss = train_args.subsample.split("_")
-            for j in range(
-                    min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
-            ):
+            for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
                 subsample[j] = int(ss[j])
         else:
             logging.warning(
@@ -494,9 +463,7 @@
         subsample_list = []
         for idx in range(train_args.num_encs):
             subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
-            if train_args.etype[idx].endswith("p") and not train_args.etype[
-                idx
-            ].startswith("vgg"):
+            if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
                 ss = train_args.subsample[idx].split("_")
                 for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                     subsample[j] = int(ss[j])
@@ -514,9 +481,7 @@
         raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
 
 
-def rename_state_dict(
-        old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
-):
+def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
     """Replace keys of old prefix with new prefix in state dict."""
     # need this list not to break the dict iterator
     old_keys = [k for k in state_dict if k.startswith(old_prefix)]
@@ -526,6 +491,7 @@
         v = state_dict.pop(k)
         new_k = k.replace(old_prefix, new_prefix)
         state_dict[new_k] = v
+
 
 class Swish(torch.nn.Module):
     """Swish activation definition.
@@ -561,6 +527,7 @@
         """Forward computation."""
         return self.swish(x)
 
+
 def get_activation(act):
     """Return activation function."""
 
@@ -573,6 +540,7 @@
     }
 
     return activation_funcs[act]()
+
 
 class TooShortUttError(Exception):
     """Raised when the utt is too short for subsampling.
@@ -634,9 +602,7 @@
     elif sub_factor == 6:
         return 5, 3, (((input_size - 1) // 2 - 2) // 3)
     else:
-        raise ValueError(
-            "subsampling_factor parameter should be set to either 2, 4 or 6."
-        )
+        raise ValueError("subsampling_factor parameter should be set to either 2, 4 or 6.")
 
 
 def make_chunk_mask(
@@ -671,6 +637,7 @@
         mask[i, start:end] = True
 
     return ~mask
+
 
 def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
     """Create source mask for given lengths.
@@ -756,6 +723,7 @@
 
     return decoder_in, target, t_len, u_len
 
+
 def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
     """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
     if t.size(dim) == pad_len:
@@ -763,6 +731,4 @@
     else:
         pad_size = list(t.shape)
         pad_size[dim] = pad_len - t.size(dim)
-        return torch.cat(
-            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
-        )
+        return torch.cat([t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim)

--
Gitblit v1.9.1