aky15
2023-03-15 96bae0153cb04c82d6e7ca7cb9654d55eb987567
rnnt bug fix
3个文件已修改
156 ■■■■ 已修改文件
funasr/bin/asr_inference_rnnt.py 145 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models_transducer/encoder/blocks/conv_input.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_rnnt.py
@@ -31,7 +31,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2bool, str2triple_str, str_or_none
from funasr.utils.cli_utils import get_commandline_args
from funasr.models.frontend.wav_frontend import WavFrontend
class Speech2Text:
    """Speech2Text class for Transducer models.
@@ -62,6 +62,7 @@
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        beam_search_config: Dict[str, Any] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
@@ -86,10 +87,13 @@
        super().__init__()
        assert check_argument_types()
        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
            asr_train_config, asr_model_file, device
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        if quantize_asr_model:
            if quantize_modules is not None:
@@ -156,7 +160,7 @@
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.device = device
@@ -181,23 +185,13 @@
            self.simu_streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512)
        self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128)
        if asr_train_args.frontend_conf.get("win_length", None) is not None:
            self.frontend_window_size = asr_train_args.frontend_conf["win_length"]
        else:
            self.frontend_window_size = self.n_fft
        self.frontend = frontend
        self.window_size = self.chunk_size + self.right_context
        self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size(
            self.window_size, self.hop_length
        )
        self._ctx = self.asr_model.encoder.get_encoder_input_size(
            self.window_size
        )
       
        #self.last_chunk_length = (
        #    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
        #) * self.hop_length
@@ -217,112 +211,6 @@
        self.beam_search.reset_inference_cache()
        self.num_processed_frames = torch.tensor([[0]], device=self.device)
    def apply_frontend(
        self, speech: torch.Tensor, is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward frontend.
        Args:
            speech: Speech data. (S)
            is_final: Whether speech corresponds to the final (or only) chunk of data.
        Returns:
            feats: Features sequence. (1, T_in, F)
            feats_lengths: Features sequence length. (1, T_in, F)
        """
        if self.frontend_cache is not None:
            speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0)
        if is_final:
            if self.streaming and speech.size(0) < self.last_chunk_length:
                pad = torch.zeros(
                    self.last_chunk_length - speech.size(0), dtype=speech.dtype
                )
                speech = torch.cat([speech, pad], dim=0)
            speech_to_process = speech
            waveform_buffer = None
        else:
            n_frames = (
                speech.size(0) - (self.frontend_window_size - self.hop_length)
            ) // self.hop_length
            n_residual = (
                speech.size(0) - (self.frontend_window_size - self.hop_length)
            ) % self.hop_length
            speech_to_process = speech.narrow(
                0,
                0,
                (self.frontend_window_size - self.hop_length)
                + n_frames * self.hop_length,
            )
            waveform_buffer = speech.narrow(
                0,
                speech.size(0)
                - (self.frontend_window_size - self.hop_length)
                - n_residual,
                (self.frontend_window_size - self.hop_length) + n_residual,
            ).clone()
        speech_to_process = speech_to_process.unsqueeze(0).to(
            getattr(torch, self.dtype)
        )
        lengths = speech_to_process.new_full(
            [1], dtype=torch.long, fill_value=speech_to_process.size(1)
        )
        batch = {"speech": speech_to_process, "speech_lengths": lengths}
        batch = to_device(batch, device=self.device)
        feats, feats_lengths = self.asr_model._extract_feats(**batch)
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
        if is_final:
            if self.frontend_cache is None:
                pass
            else:
                feats = feats.narrow(
                    1,
                    math.ceil(
                        math.ceil(self.frontend_window_size / self.hop_length) / 2
                    ),
                    feats.size(1)
                    - math.ceil(
                        math.ceil(self.frontend_window_size / self.hop_length) / 2
                    ),
                )
        else:
            if self.frontend_cache is None:
                feats = feats.narrow(
                    1,
                    0,
                    feats.size(1)
                    - math.ceil(
                        math.ceil(self.frontend_window_size / self.hop_length) / 2
                    ),
                )
            else:
                feats = feats.narrow(
                    1,
                    math.ceil(
                        math.ceil(self.frontend_window_size / self.hop_length) / 2
                    ),
                    feats.size(1)
                    - 2
                    * math.ceil(
                        math.ceil(self.frontend_window_size / self.hop_length) / 2
                    ),
                )
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if is_final:
            self.frontend_cache = None
        else:
            self.frontend_cache = {"waveform_buffer": waveform_buffer}
        return feats, feats_lengths
    @torch.no_grad()
    def streaming_decode(
@@ -410,14 +298,9 @@
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        
        # lengths: (1,)
        # feats, feats_length = self.apply_frontend(speech)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        # lengths: (1,)
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        # print(feats.shape)
        # print(feats_lengths)
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
@@ -495,6 +378,7 @@
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str],
    beam_search_config: Optional[dict],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
@@ -562,7 +446,6 @@
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
@@ -570,6 +453,7 @@
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        beam_search_config=beam_search_config,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
@@ -720,6 +604,11 @@
        help="ASR model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global cmvn file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
funasr/models_transducer/encoder/blocks/conv_input.py
@@ -120,7 +120,7 @@
                self.create_new_mask = self.create_new_conv2d_mask
        self.vgg_like = vgg_like
        self.min_frame_length = 2
        self.min_frame_length = 7
        if output_size is not None:
            self.output = torch.nn.Linear(output_proj, output_size)
@@ -218,9 +218,4 @@
            : Number of frames before subsampling.
        """
        if self.subsampling_factor > 1:
            if self.vgg_like:
                return ((size * 2) * self.stride_1) + 1
            return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2
        return size
        return size * self.subsampling_factor
funasr/tasks/abs_task.py
@@ -1576,7 +1576,7 @@
            preprocess=iter_options.preprocess_fn,
            max_cache_size=iter_options.max_cache_size,
            max_cache_fd=iter_options.max_cache_fd,
            dest_sample_rate=args.frontend_conf["fs"],
            dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000,
        )
        cls.check_task_requirements(
            dataset, args.allow_variable_data_keys, train=iter_options.train