From 96bae0153cb04c82d6e7ca7cb9654d55eb987567 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 15 三月 2023 17:34:34 +0800
Subject: [PATCH] rnnt bug fix

---
 funasr/bin/asr_inference_rnnt.py                      |  145 +++++------------------------------------------
 funasr/tasks/abs_task.py                              |    2 
 funasr/models_transducer/encoder/blocks/conv_input.py |    9 --
 3 files changed, 20 insertions(+), 136 deletions(-)

diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index f651f11..c8a2916 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -31,7 +31,7 @@
 from funasr.utils import config_argparse
 from funasr.utils.types import str2bool, str2triple_str, str_or_none
 from funasr.utils.cli_utils import get_commandline_args
-
+from funasr.models.frontend.wav_frontend import WavFrontend
 
 class Speech2Text:
     """Speech2Text class for Transducer models.
@@ -62,6 +62,7 @@
         self,
         asr_train_config: Union[Path, str] = None,
         asr_model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
         beam_search_config: Dict[str, Any] = None,
         lm_train_config: Union[Path, str] = None,
         lm_file: Union[Path, str] = None,
@@ -86,10 +87,13 @@
         super().__init__()
 
         assert check_argument_types()
-
         asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
-            asr_train_config, asr_model_file, device
+            asr_train_config, asr_model_file, cmvn_file, device
         )
+
+        frontend = None
+        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
 
         if quantize_asr_model:
             if quantize_modules is not None:
@@ -156,7 +160,7 @@
             tokenizer = build_tokenizer(token_type=token_type)
         converter = TokenIDConverter(token_list=token_list)
         logging.info(f"Text tokenizer: {tokenizer}")
-
+        
         self.asr_model = asr_model
         self.asr_train_args = asr_train_args
         self.device = device
@@ -181,23 +185,13 @@
             self.simu_streaming = False
             self.asr_model.encoder.dynamic_chunk_training = False
 
-        self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512)
-        self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128)
-
-        if asr_train_args.frontend_conf.get("win_length", None) is not None:
-            self.frontend_window_size = asr_train_args.frontend_conf["win_length"]
-        else:
-            self.frontend_window_size = self.n_fft
-
+        self.frontend = frontend
         self.window_size = self.chunk_size + self.right_context
-        self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size(
-            self.window_size, self.hop_length
-        )
+        
         self._ctx = self.asr_model.encoder.get_encoder_input_size(
             self.window_size
         )
        
-
         #self.last_chunk_length = (
         #    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
         #) * self.hop_length
@@ -217,112 +211,6 @@
         self.beam_search.reset_inference_cache()
 
         self.num_processed_frames = torch.tensor([[0]], device=self.device)
-
-    def apply_frontend(
-        self, speech: torch.Tensor, is_final: bool = False
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward frontend.
-        Args:
-            speech: Speech data. (S)
-            is_final: Whether speech corresponds to the final (or only) chunk of data.
-        Returns:
-            feats: Features sequence. (1, T_in, F)
-            feats_lengths: Features sequence length. (1, T_in, F)
-        """
-        if self.frontend_cache is not None:
-            speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0)
-
-        if is_final:
-            if self.streaming and speech.size(0) < self.last_chunk_length:
-                pad = torch.zeros(
-                    self.last_chunk_length - speech.size(0), dtype=speech.dtype
-                )
-                speech = torch.cat([speech, pad], dim=0)
-
-            speech_to_process = speech
-            waveform_buffer = None
-        else:
-            n_frames = (
-                speech.size(0) - (self.frontend_window_size - self.hop_length)
-            ) // self.hop_length
-
-            n_residual = (
-                speech.size(0) - (self.frontend_window_size - self.hop_length)
-            ) % self.hop_length
-
-            speech_to_process = speech.narrow(
-                0,
-                0,
-                (self.frontend_window_size - self.hop_length)
-                + n_frames * self.hop_length,
-            )
-
-            waveform_buffer = speech.narrow(
-                0,
-                speech.size(0)
-                - (self.frontend_window_size - self.hop_length)
-                - n_residual,
-                (self.frontend_window_size - self.hop_length) + n_residual,
-            ).clone()
-
-        speech_to_process = speech_to_process.unsqueeze(0).to(
-            getattr(torch, self.dtype)
-        )
-        lengths = speech_to_process.new_full(
-            [1], dtype=torch.long, fill_value=speech_to_process.size(1)
-        )
-        batch = {"speech": speech_to_process, "speech_lengths": lengths}
-        batch = to_device(batch, device=self.device)
-
-        feats, feats_lengths = self.asr_model._extract_feats(**batch)
-        if self.asr_model.normalize is not None:
-            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
-        if is_final:
-            if self.frontend_cache is None:
-                pass
-            else:
-                feats = feats.narrow(
-                    1,
-                    math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                    feats.size(1)
-                    - math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                )
-        else:
-            if self.frontend_cache is None:
-                feats = feats.narrow(
-                    1,
-                    0,
-                    feats.size(1)
-                    - math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                )
-            else:
-                feats = feats.narrow(
-                    1,
-                    math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                    feats.size(1)
-                    - 2
-                    * math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                )
-
-        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
-        if is_final:
-            self.frontend_cache = None
-        else:
-            self.frontend_cache = {"waveform_buffer": waveform_buffer}
-
-        return feats, feats_lengths
 
     @torch.no_grad()
     def streaming_decode(
@@ -410,14 +298,9 @@
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
         
-        # lengths: (1,)
-        # feats, feats_length = self.apply_frontend(speech)
         feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        # lengths: (1,)
         feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
 
-        # print(feats.shape)
-        # print(feats_lengths)
         if self.asr_model.normalize is not None:
             feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
 
@@ -495,6 +378,7 @@
     data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
     asr_train_config: Optional[str],
     asr_model_file: Optional[str],
+    cmvn_file: Optional[str],
     beam_search_config: Optional[dict],
     lm_train_config: Optional[str],
     lm_file: Optional[str],
@@ -562,7 +446,6 @@
         device = "cuda"
     else:
         device = "cpu"
-
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -570,6 +453,7 @@
     speech2text_kwargs = dict(
         asr_train_config=asr_train_config,
         asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
         beam_search_config=beam_search_config,
         lm_train_config=lm_train_config,
         lm_file=lm_file,
@@ -720,6 +604,11 @@
         help="ASR model parameter file",
     )
     group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global cmvn file",
+    )
+    group.add_argument(
         "--lm_train_config",
         type=str,
         help="LM training configuration",
diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models_transducer/encoder/blocks/conv_input.py
index 931d0f0..c68c73b 100644
--- a/funasr/models_transducer/encoder/blocks/conv_input.py
+++ b/funasr/models_transducer/encoder/blocks/conv_input.py
@@ -120,7 +120,7 @@
                 self.create_new_mask = self.create_new_conv2d_mask
 
         self.vgg_like = vgg_like
-        self.min_frame_length = 2
+        self.min_frame_length = 7
 
         if output_size is not None:
             self.output = torch.nn.Linear(output_proj, output_size)
@@ -218,9 +218,4 @@
             : Number of frames before subsampling.
 
         """
-        if self.subsampling_factor > 1:
-            if self.vgg_like:
-                return ((size * 2) * self.stride_1) + 1
-
-            return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2
-        return size
+        return size * self.subsampling_factor
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index e0884ce..cc5b708 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -1576,7 +1576,7 @@
             preprocess=iter_options.preprocess_fn,
             max_cache_size=iter_options.max_cache_size,
             max_cache_fd=iter_options.max_cache_fd,
-            dest_sample_rate=args.frontend_conf["fs"],
+            dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000,
         )
         cls.check_task_requirements(
             dataset, args.allow_variable_data_keys, train=iter_options.train

--
Gitblit v1.9.1