From 3c3754dcc7568e76fa7d4b2c4e14849f68cc6ee7 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期日, 28 五月 2023 23:46:01 +0800
Subject: [PATCH] update repo
---
funasr/bin/asr_infer.py | 40 ++++++++++++++++++++++++++++------------
1 files changed, 28 insertions(+), 12 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 03145f8..9da7ef7 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -9,6 +9,7 @@
import time
import copy
import os
+import re
import codecs
import tempfile
import requests
@@ -488,15 +489,20 @@
nbest_hyps = nbest_hyps[: self.nbest]
else:
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
- )
+ if pre_token_length[i] == 0:
+ yseq = torch.tensor(
+ [self.asr_model.sos] + [self.asr_model.eos], device=yseq.device
+ )
+ score = torch.tensor(0.0, device=yseq.device)
+ else:
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
+ )
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
@@ -749,10 +755,13 @@
feats = cache_en["feats"]
feats_len = torch.tensor([feats.shape[1]])
self.asr_model.frontend = None
+ self.frontend.cache_reset()
results = self.infer(feats, feats_len, cache)
return results
else:
if self.frontend is not None:
+ if cache_en["start_idx"] == 0:
+ self.frontend.cache_reset()
feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
@@ -820,9 +829,16 @@
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
- token = " ".join(token)
-
- results.append(token)
+ postprocessed_result = ""
+ for item in token:
+ if item.endswith('@@'):
+ postprocessed_result += item[:-2]
+ elif re.match('^[a-zA-Z]+$', item):
+ postprocessed_result += item + " "
+ else:
+ postprocessed_result += item
+
+ results.append(postprocessed_result)
# assert check_return_type(results)
return results
@@ -1581,7 +1597,7 @@
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
- return Speech2Text(**kwargs)
+ return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
--
Gitblit v1.9.1