from funasr.utils import config_argparse
from funasr.utils.types import str2bool, str2triple_str, str_or_none
from funasr.utils.cli_utils import get_commandline_args

from funasr.models.frontend.wav_frontend import WavFrontend


class Speech2Text:
    """Speech2Text class for Transducer models.
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        beam_search_config: Dict[str, Any] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        device: str = "cpu",
        dtype: str = "float32",
        quantize_asr_model: bool = False,
        quantize_modules: List[str] = None,
        streaming: bool = False,
        chunk_size: int = 16,
        right_context: int = 0,
    ) -> None:
        super().__init__()

        assert check_argument_types()

        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )

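        # WavFrontend reproduces the training-time feature pipeline (fbank
        # extraction plus global CMVN from cmvn_file), so streaming chunks
        # see the same features as full utterances did during training.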
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)

        if quantize_asr_model:
            # Dynamically quantize the selected module types (Linear if
            # unspecified); qint8 is assumed as the quantized dtype.
            if quantize_modules is not None:
                q_config = {getattr(torch.nn, q) for q in quantize_modules}
            else:
                q_config = {torch.nn.Linear}
            asr_model = torch.quantization.quantize_dynamic(
                asr_model, q_config, dtype=torch.qint8
            ).eval()

        token_list = asr_model.token_list
        tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.device = device
        self.dtype = dtype

        self.simu_streaming = False
        self.asr_model.encoder.dynamic_chunk_training = False

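        # STFT geometry of the frontend; win_length falls back to n_fft when
        # unset, matching torch.stft's default behaviour.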
        self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512)
        self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128)

        if asr_train_args.frontend_conf.get("win_length", None) is not None:
            self.frontend_window_size = asr_train_args.frontend_conf["win_length"]
        else:
            self.frontend_window_size = self.n_fft

        self.frontend = frontend
        self.streaming = streaming
        self.chunk_size = chunk_size
        self.right_context = right_context
        self.window_size = self.chunk_size + self.right_context
        self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size(
            self.window_size, self.hop_length
        )

        self._ctx = self.asr_model.encoder.get_encoder_input_size(
            self.window_size
        )
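        # window_size counts feature frames; the encoder helpers above map it
        # to the raw-sample span (_raw_ctx) and to the post-subsampling input
        # length (_ctx) required per chunk.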

        self.last_chunk_length = (
            self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
        ) * self.hop_length

        self.frontend_cache = None
        self.beam_search.reset_inference_cache()

        self.num_processed_frames = torch.tensor([[0]], device=self.device)

    def apply_frontend(
        self, speech: torch.Tensor, is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward frontend.

        Args:
            speech: Speech data. (S)
            is_final: Whether speech corresponds to the final (or only) chunk of data.

        Returns:
            feats: Features sequence. (1, T_in, F)
            feats_lengths: Features sequence length. (1,)
        """
        if self.frontend_cache is not None:
            speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0)

        if is_final:
            if self.streaming and speech.size(0) < self.last_chunk_length:
                pad = torch.zeros(
                    self.last_chunk_length - speech.size(0), dtype=speech.dtype
                )
                speech = torch.cat([speech, pad], dim=0)

            speech_to_process = speech
            waveform_buffer = None
        else:
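            # With window w and hop h, S samples yield
            # n_frames = (S - (w - h)) // h complete frames; the trailing
            # (w - h) + n_residual samples are carried over to the next chunk.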
            n_frames = (
                speech.size(0) - (self.frontend_window_size - self.hop_length)
            ) // self.hop_length

            n_residual = (
                speech.size(0) - (self.frontend_window_size - self.hop_length)
            ) % self.hop_length

            speech_to_process = speech.narrow(
                0,
                0,
                (self.frontend_window_size - self.hop_length)
                + n_frames * self.hop_length,
            )

            waveform_buffer = speech.narrow(
                0,
                speech.size(0)
                - (self.frontend_window_size - self.hop_length)
                - n_residual,
                (self.frontend_window_size - self.hop_length) + n_residual,
            ).clone()
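            # The .clone() above copies the tail out of `speech`'s storage so
            # the cached buffer does not keep the whole chunk alive in memory.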

        speech_to_process = speech_to_process.unsqueeze(0).to(
            getattr(torch, self.dtype)
        )
        lengths = speech_to_process.new_full(
            [1], dtype=torch.long, fill_value=speech_to_process.size(1)
        )
        batch = {"speech": speech_to_process, "speech_lengths": lengths}
        batch = to_device(batch, device=self.device)

        feats, feats_lengths = self.asr_model._extract_feats(**batch)
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
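
        # Each chunk boundary carries ceil(ceil(w / h) / 2) frames of STFT
        # context; trim them so consecutive chunks never emit overlapping
        # feature frames.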
        n_ctx_frames = math.ceil(
            math.ceil(self.frontend_window_size / self.hop_length) / 2
        )

        if is_final:
            if self.frontend_cache is not None:
                feats = feats.narrow(
                    1, n_ctx_frames, feats.size(1) - n_ctx_frames
                )
        else:
            if self.frontend_cache is None:
                feats = feats.narrow(
                    1, 0, feats.size(1) - n_ctx_frames
                )
            else:
                feats = feats.narrow(
                    1, n_ctx_frames, feats.size(1) - 2 * n_ctx_frames
                )

        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

        if is_final:
            self.frontend_cache = None
        else:
            self.frontend_cache = {"waveform_buffer": waveform_buffer}

        return feats, feats_lengths
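
    # Usage sketch for chunked feature extraction (hypothetical driver loop:
    # `s2t` is a constructed Speech2Text instance and `sample_blocks` a list
    # of 1-D waveform tensors; neither name exists in this module):
    #
    #   for i, block in enumerate(sample_blocks):
    #       feats, feats_lengths = s2t.apply_frontend(
    #           block, is_final=(i == len(sample_blocks) - 1)
    #       )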

    @torch.no_grad()
    def streaming_decode(
        self, speech: Union[torch.Tensor, np.ndarray]
    ):
        """Speech2Text streaming call."""
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        # apply_frontend is bypassed here: `speech` is treated as a
        # pre-extracted (T, F) feature chunk.
        # feats, feats_length = self.apply_frontend(speech)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        # lengths: (1,)
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)


    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str],
    beam_search_config: Optional[dict],
    lm_train_config: Optional[str],
    lm_file: Optional[str],

    if ngpu >= 1:  # `ngpu` is assumed from the elided signature above
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        beam_search_config=beam_search_config,
        lm_train_config=lm_train_config,
        lm_file=lm_file,

        help="ASR model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global CMVN file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",