From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/transducer/beam_search_transducer.py |   34 +++++++++++-----------------------
 1 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/funasr/models/transducer/beam_search_transducer.py b/funasr/models/transducer/beam_search_transducer.py
index f599615..aa9e34e 100644
--- a/funasr/models/transducer/beam_search_transducer.py
+++ b/funasr/models/transducer/beam_search_transducer.py
@@ -92,21 +92,18 @@
 
         self.vocab_size = decoder.vocab_size
 
-        assert beam_size <= self.vocab_size, (
-            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
-            % (
-                beam_size,
-                self.vocab_size,
-            )
+        assert (
+            beam_size <= self.vocab_size
+        ), "beam_size (%d) should be smaller than or equal to vocabulary size (%d)." % (
+            beam_size,
+            self.vocab_size,
         )
         self.beam_size = beam_size
 
         if search_type == "default":
             self.search_algorithm = self.default_beam_search
         elif search_type == "tsd":
-            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
-                max_sym_exp
-            )
+            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (max_sym_exp)
             self.max_sym_exp = max_sym_exp
 
             self.search_algorithm = self.time_sync_decoding
@@ -130,9 +127,7 @@
 
             self.search_algorithm = self.modified_adaptive_expansion_search
         else:
-            raise NotImplementedError(
-                "Specified search type (%s) is not supported." % search_type
-            )
+            raise NotImplementedError("Specified search type (%s) is not supported." % search_type)
 
         self.use_lm = lm is not None
 
@@ -244,17 +239,12 @@
         k_expansions = []
 
         for i, hyp in enumerate(hyps):
-            hyp_i = [
-                (int(k), hyp.score + float(v))
-                for k, v in zip(topk_idx[i], topk_logp[i])
-            ]
+            hyp_i = [(int(k), hyp.score + float(v)) for k, v in zip(topk_idx[i], topk_logp[i])]
             k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
 
             k_expansions.append(
                 sorted(
-                    filter(
-                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
-                    ),
+                    filter(lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i),
                     key=lambda x: x[1],
                     reverse=True,
                 )
@@ -342,9 +332,7 @@
 
                 if self.use_lm:
                     lm_scores, lm_state = self.lm.score(
-                        torch.LongTensor(
-                            [self.sos] + max_hyp.yseq[1:], device=self.decoder.device
-                        ),
+                        torch.LongTensor([self.sos] + max_hyp.yseq[1:], device=self.decoder.device),
                         max_hyp.lm_state,
                         None,
                     )
@@ -376,7 +364,7 @@
                     break
 
         return kept_hyps
-    
+
     def align_length_sync_decoding(
         self,
         enc_out: torch.Tensor,

--
Gitblit v1.9.1