From b15db52e4e67da8a133a67e8ffa415386de48b40 Mon Sep 17 00:00:00 2001
From: zhuyunfeng <10596244@qq.com>
Date: 星期二, 09 五月 2023 23:03:15 +0800
Subject: [PATCH] Add contributor

---
 funasr/bin/asr_inference_rnnt.py |   27 ++++++++++++---------------
 1 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index f65bd07..bd36907 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -16,13 +16,13 @@
 from packaging.version import parse as V
 from typeguard import check_argument_types, check_return_type
 
-from funasr.models_transducer.beam_search_transducer import (
+from funasr.modules.beam_search.beam_search_transducer import (
     BeamSearchTransducer,
     Hypothesis,
 )
-from funasr.models_transducer.utils import TooShortUttError
+from funasr.modules.nets_utils import TooShortUttError
 from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.tasks.asr_transducer import ASRTransducerTask
+from funasr.tasks.asr import ASRTransducerTask
 from funasr.tasks.lm import LMTask
 from funasr.text.build_tokenizer import build_tokenizer
 from funasr.text.token_id_converter import TokenIDConverter
@@ -174,7 +174,7 @@
         self.streaming = streaming
         self.simu_streaming = simu_streaming
         self.chunk_size = max(chunk_size, 0)
-        self.left_context = max(left_context, 0)
+        self.left_context = left_context
         self.right_context = max(right_context, 0)
 
         if not streaming or chunk_size == 0:
@@ -188,18 +188,15 @@
         self.frontend = frontend
         self.window_size = self.chunk_size + self.right_context
         
-        self._ctx = self.asr_model.encoder.get_encoder_input_size(
-            self.window_size
-        )
+        if self.streaming:
+            self._ctx = self.asr_model.encoder.get_encoder_input_size(
+                self.window_size
+            )
        
-        #self.last_chunk_length = (
-        #    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
-        #) * self.hop_length
-
-        self.last_chunk_length = (
-            self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
-        )
-        self.reset_inference_cache()
+            self.last_chunk_length = (
+                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+            )
+            self.reset_inference_cache()
 
     def reset_inference_cache(self) -> None:
         """Reset Speech2Text parameters."""

--
Gitblit v1.9.1