From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/transducer/beam_search_transducer.py |   44 +++++++++++++++++---------------------------
 1 files changed, 17 insertions(+), 27 deletions(-)

diff --git a/funasr/models/transducer/beam_search_transducer.py b/funasr/models/transducer/beam_search_transducer.py
index 04b26b3..aa9e34e 100644
--- a/funasr/models/transducer/beam_search_transducer.py
+++ b/funasr/models/transducer/beam_search_transducer.py
@@ -1,10 +1,12 @@
-"""Search algorithms for Transducer models."""
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
+import torch
+import numpy as np
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
 
 from funasr.models.transducer.joint_network import JointNetwork
 
@@ -90,21 +92,18 @@
 
         self.vocab_size = decoder.vocab_size
 
-        assert beam_size <= self.vocab_size, (
-            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
-            % (
-                beam_size,
-                self.vocab_size,
-            )
+        assert (
+            beam_size <= self.vocab_size
+        ), "beam_size (%d) should be smaller than or equal to vocabulary size (%d)." % (
+            beam_size,
+            self.vocab_size,
         )
         self.beam_size = beam_size
 
         if search_type == "default":
             self.search_algorithm = self.default_beam_search
         elif search_type == "tsd":
-            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
-                max_sym_exp
-            )
+            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (max_sym_exp)
             self.max_sym_exp = max_sym_exp
 
             self.search_algorithm = self.time_sync_decoding
@@ -128,9 +127,7 @@
 
             self.search_algorithm = self.modified_adaptive_expansion_search
         else:
-            raise NotImplementedError(
-                "Specified search type (%s) is not supported." % search_type
-            )
+            raise NotImplementedError("Specified search type (%s) is not supported." % search_type)
 
         self.use_lm = lm is not None
 
@@ -242,17 +239,12 @@
         k_expansions = []
 
         for i, hyp in enumerate(hyps):
-            hyp_i = [
-                (int(k), hyp.score + float(v))
-                for k, v in zip(topk_idx[i], topk_logp[i])
-            ]
+            hyp_i = [(int(k), hyp.score + float(v)) for k, v in zip(topk_idx[i], topk_logp[i])]
             k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
 
             k_expansions.append(
                 sorted(
-                    filter(
-                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
-                    ),
+                    filter(lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i),
                     key=lambda x: x[1],
                     reverse=True,
                 )
@@ -340,9 +332,7 @@
 
                 if self.use_lm:
                     lm_scores, lm_state = self.lm.score(
-                        torch.LongTensor(
-                            [self.sos] + max_hyp.yseq[1:], device=self.decoder.device
-                        ),
+                        torch.LongTensor([self.sos] + max_hyp.yseq[1:], device=self.decoder.device),
                         max_hyp.lm_state,
                         None,
                     )
@@ -374,7 +364,7 @@
                     break
 
         return kept_hyps
-    
+
     def align_length_sync_decoding(
         self,
         enc_out: torch.Tensor,

--
Gitblit v1.9.1