From da98950b422bd14d2c9357a878c19268b196b9c0 Mon Sep 17 00:00:00 2001
From: 仁迷 <haoneng.lhn@alibaba-inc.com>
Date: Thu, 29 Dec 2022 14:12:58 +0800
Subject: [PATCH] fix uniasr inference bug
---
funasr/bin/asr_inference_uniasr.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index 9aea1a3..2e87675 100755
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -215,14 +215,14 @@
lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
# lengths: (1,)
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- speech_raw = speech.clone().to(self.device)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
else:
- feats = speech_raw
+ feats = speech
feats_len = lengths
+ feats_raw = feats.clone().to(self.device)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
@@ -235,7 +235,7 @@
if self.decoding_mode == "model1":
predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
else:
- enc, enc_len = self.asr_model.encode2(enc, enc_len, feats, feats_len, ind=self.decoding_ind)
+ enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
scama_mask = predictor_outs[4]
--
Gitblit v1.9.1