From 9633e64bb1d7aef3ea49fe1e4ed3d7fab838b52e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 06 六月 2024 18:46:29 +0800
Subject: [PATCH] auto frontend
---
funasr/metrics/common.py | 14 +++++---------
1 files changed, 5 insertions(+), 9 deletions(-)
diff --git a/funasr/metrics/common.py b/funasr/metrics/common.py
index 92f9079..2443e0d 100644
--- a/funasr/metrics/common.py
+++ b/funasr/metrics/common.py
@@ -36,9 +36,9 @@
hyp_length = i - m
hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
if len(hyps_same_length) > 0:
- best_hyp_same_length = sorted(
- hyps_same_length, key=lambda x: x["score"], reverse=True
- )[0]
+ best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x["score"], reverse=True)[
+ 0
+ ]
if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
count += 1
@@ -63,9 +63,7 @@
trans_json = json.load(f)["utts"]
if lsm_type == "unigram":
- assert transcript is not None, (
- "transcript is required for %s label smoothing" % lsm_type
- )
+ assert transcript is not None, "transcript is required for %s label smoothing" % lsm_type
labelcount = np.zeros(odim)
for k, v in trans_json.items():
ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
@@ -108,9 +106,7 @@
:return:
"""
- def __init__(
- self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
- ):
+ def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False):
"""Construct an ErrorCalculator object."""
super(ErrorCalculator, self).__init__()
--
Gitblit v1.9.1