From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 19 十二月 2023 21:58:14 +0800
Subject: [PATCH] funasr2
---
funasr/models/transformer/utils/nets_utils.py | 21 ---------------------
1 files changed, 0 insertions(+), 21 deletions(-)
diff --git a/funasr/models/transformer/utils/nets_utils.py b/funasr/models/transformer/utils/nets_utils.py
index 0beb083..ce151a0 100644
--- a/funasr/models/transformer/utils/nets_utils.py
+++ b/funasr/models/transformer/utils/nets_utils.py
@@ -342,27 +342,6 @@
return ret
-def th_accuracy(pad_outputs, pad_targets, ignore_label):
- """Calculate accuracy.
-
- Args:
- pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
- pad_targets (LongTensor): Target label tensors (B, Lmax).
- ignore_label (int): Ignore label id.
-
- Returns:
- float: Accuracy value (0.0 - 1.0).
-
- """
- pad_pred = pad_outputs.view(
- pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
- ).argmax(2)
- mask = pad_targets != ignore_label
- numerator = torch.sum(
- pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
- )
- denominator = torch.sum(mask)
- return float(numerator) / float(denominator)
def to_torch_tensor(x):
--
Gitblit v1.9.1