From 7eaf608c2d4473a77bd1590f93ea9bdbedde346a Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 19 五月 2023 11:31:33 +0800
Subject: [PATCH] Merge pull request #531 from alibaba-damo-academy/dev_new
---
funasr/bin/asr_infer.py | 43 +++++++++++++++++--------------------------
1 files changed, 17 insertions(+), 26 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index f6c5504..acb5fd8 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -488,15 +488,20 @@
nbest_hyps = nbest_hyps[: self.nbest]
else:
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
- )
+ if pre_token_length[i] == 0:
+ yseq = torch.tensor(
+ [self.asr_model.sos] + [self.asr_model.eos], device=yseq.device
+ )
+ score = torch.tensor(0.0, device=yseq.device)
+ else:
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
+ )
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
@@ -749,10 +754,13 @@
feats = cache_en["feats"]
feats_len = torch.tensor([feats.shape[1]])
self.asr_model.frontend = None
+ self.frontend.cache_reset()
results = self.infer(feats, feats_len, cache)
return results
else:
if self.frontend is not None:
+ if cache_en["start_idx"] == 0:
+ self.frontend.cache_reset()
feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
@@ -762,23 +770,6 @@
feats_len = speech_lengths
if feats.shape[1] != 0:
- if cache_en["is_final"]:
- if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]:
- cache_en["last_chunk"] = True
- else:
- # first chunk
- feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :]
- feats_len = torch.tensor([feats_chunk1.shape[1]])
- results_chunk1 = self.infer(feats_chunk1, feats_len, cache)
-
- # last chunk
- cache_en["last_chunk"] = True
- feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :]
- feats_len = torch.tensor([feats_chunk2.shape[1]])
- results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
-
- return [" ".join(results_chunk1 + results_chunk2)]
-
results = self.infer(feats, feats_len, cache)
return results
@@ -1598,7 +1589,7 @@
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
- return Speech2Text(**kwargs)
+ return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
--
Gitblit v1.9.1