From 5f25e809c5acb3a24d9eca942e7540f5ddf6d361 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 05 十一月 2024 16:33:27 +0800
Subject: [PATCH] Update version.txt

---
 funasr/metrics/common.py |   14 +++++---------
 1 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/funasr/metrics/common.py b/funasr/metrics/common.py
index 92f9079..2443e0d 100644
--- a/funasr/metrics/common.py
+++ b/funasr/metrics/common.py
@@ -36,9 +36,9 @@
         hyp_length = i - m
         hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
         if len(hyps_same_length) > 0:
-            best_hyp_same_length = sorted(
-                hyps_same_length, key=lambda x: x["score"], reverse=True
-            )[0]
+            best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x["score"], reverse=True)[
+                0
+            ]
             if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
                 count += 1
 
@@ -63,9 +63,7 @@
             trans_json = json.load(f)["utts"]
 
     if lsm_type == "unigram":
-        assert transcript is not None, (
-            "transcript is required for %s label smoothing" % lsm_type
-        )
+        assert transcript is not None, "transcript is required for %s label smoothing" % lsm_type
         labelcount = np.zeros(odim)
         for k, v in trans_json.items():
             ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
@@ -108,9 +106,7 @@
     :return:
     """
 
-    def __init__(
-        self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
-    ):
+    def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False):
         """Construct an ErrorCalculator object."""
         super(ErrorCalculator, self).__init__()
 

--
Gitblit v1.9.1