From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/transducer/beam_search_transducer.py | 44 +++++++++++++++++---------------------------
1 files changed, 17 insertions(+), 27 deletions(-)
diff --git a/funasr/models/transducer/beam_search_transducer.py b/funasr/models/transducer/beam_search_transducer.py
index 04b26b3..aa9e34e 100644
--- a/funasr/models/transducer/beam_search_transducer.py
+++ b/funasr/models/transducer/beam_search_transducer.py
@@ -1,10 +1,12 @@
-"""Search algorithms for Transducer models."""
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+import torch
+import numpy as np
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
from funasr.models.transducer.joint_network import JointNetwork
@@ -90,21 +92,18 @@
self.vocab_size = decoder.vocab_size
- assert beam_size <= self.vocab_size, (
- "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
- % (
- beam_size,
- self.vocab_size,
- )
+ assert (
+ beam_size <= self.vocab_size
+ ), "beam_size (%d) should be smaller than or equal to vocabulary size (%d)." % (
+ beam_size,
+ self.vocab_size,
)
self.beam_size = beam_size
if search_type == "default":
self.search_algorithm = self.default_beam_search
elif search_type == "tsd":
- assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
- max_sym_exp
- )
+ assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (max_sym_exp)
self.max_sym_exp = max_sym_exp
self.search_algorithm = self.time_sync_decoding
@@ -128,9 +127,7 @@
self.search_algorithm = self.modified_adaptive_expansion_search
else:
- raise NotImplementedError(
- "Specified search type (%s) is not supported." % search_type
- )
+ raise NotImplementedError("Specified search type (%s) is not supported." % search_type)
self.use_lm = lm is not None
@@ -242,17 +239,12 @@
k_expansions = []
for i, hyp in enumerate(hyps):
- hyp_i = [
- (int(k), hyp.score + float(v))
- for k, v in zip(topk_idx[i], topk_logp[i])
- ]
+ hyp_i = [(int(k), hyp.score + float(v)) for k, v in zip(topk_idx[i], topk_logp[i])]
k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
k_expansions.append(
sorted(
- filter(
- lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
- ),
+ filter(lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i),
key=lambda x: x[1],
reverse=True,
)
@@ -340,9 +332,7 @@
if self.use_lm:
lm_scores, lm_state = self.lm.score(
- torch.LongTensor(
- [self.sos] + max_hyp.yseq[1:], device=self.decoder.device
- ),
+ torch.LongTensor([self.sos] + max_hyp.yseq[1:], device=self.decoder.device),
max_hyp.lm_state,
None,
)
@@ -374,7 +364,7 @@
break
return kept_hyps
-
+
def align_length_sync_decoding(
self,
enc_out: torch.Tensor,
--
Gitblit v1.9.1