From da98950b422bd14d2c9357a878c19268b196b9c0 Mon Sep 17 00:00:00 2001
From: 仁迷 <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 29 十二月 2022 14:12:58 +0800
Subject: [PATCH] fix uniasr inference bug

---
 funasr/bin/asr_inference_uniasr.py |    6 +++---
 1 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index 9aea1a3..2e87675 100755
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -215,14 +215,14 @@
         lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
         # lengths: (1,)
         lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
-        speech_raw = speech.clone().to(self.device)
         if self.frontend is not None:
             feats, feats_len = self.frontend.forward(speech, lengths)
             feats = to_device(feats, device=self.device)
             feats_len = feats_len.int()
         else:
-            feats = speech_raw
+            feats = speech
             feats_len = lengths
+        feats_raw = feats.clone().to(self.device)
         batch = {"speech": feats, "speech_lengths": feats_len}
 
         # a. To device
@@ -235,7 +235,7 @@
         if self.decoding_mode == "model1":
             predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
         else:
-            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats, feats_len, ind=self.decoding_ind)
+            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
             predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
 
         scama_mask = predictor_outs[4]

--
Gitblit v1.9.1