From 33a9e08dc9b65abc3f3e18d14253f95c79e0f749 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 08 五月 2024 19:20:43 +0800
Subject: [PATCH] dynamic batch
---
funasr/models/sense_voice/search.py | 9 ++++-----
1 files changed, 4 insertions(+), 5 deletions(-)
diff --git a/funasr/models/sense_voice/search.py b/funasr/models/sense_voice/search.py
index 4400ce7..3a1a049 100644
--- a/funasr/models/sense_voice/search.py
+++ b/funasr/models/sense_voice/search.py
@@ -54,7 +54,6 @@
event_bg_token: List[int] = field(default_factory=lambda: [58946, 58948, 58950, 58952]),
event_ed_token: List[int] = field(default_factory=lambda: [58947, 58949, 58951, 58953]),
event_score_ga: List[float] = field(default_factory=lambda: [1, 1, 5, 25]),
-
token_list: List[str] = None,
pre_beam_ratio: float = 1.5,
pre_beam_score_key: str = None,
@@ -209,15 +208,16 @@
last_token = yseq[-1]
if last_token in self.emo_tokens + [self.emo_unk]:
- # prevent output event after emotation token
+ # prevent output event after emotation token
score[self.event_bg_token] = -np.inf
- for eve_bg, eve_ed, eve_ga in zip(self.event_bg_token, self.event_ed_token, self.event_score_ga):
+ for eve_bg, eve_ed, eve_ga in zip(
+ self.event_bg_token, self.event_ed_token, self.event_score_ga
+ ):
score_offset = get_score(yseq, eve_bg, eve_ed)
score[eve_bg] += score_offset[0]
score[eve_ed] += score_offset[1]
score[eve_bg] += math.log(eve_ga)
-
score[self.emo_unk] += math.log(self.emo_unk_score)
for emo, emo_th in zip(self.emo_tokens, self.emo_scores):
@@ -231,7 +231,6 @@
scores[k] = struct_score(hyp.yseq, scores[k])
return scores, states
-
def score_partial(
self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
--
Gitblit v1.9.1