| | |
| | | from packaging.version import parse as V |
| | | from typeguard import check_argument_types, check_return_type |
| | | |
| | | from funasr.models_transducer.beam_search_transducer import ( |
| | | from funasr.modules.beam_search.beam_search_transducer import ( |
| | | BeamSearchTransducer, |
| | | Hypothesis, |
| | | ) |
| | | from funasr.models_transducer.utils import TooShortUttError |
| | | from funasr.modules.nets_utils import TooShortUttError |
| | | from funasr.fileio.datadir_writer import DatadirWriter |
| | | from funasr.tasks.asr_transducer import ASRTransducerTask |
| | | from funasr.tasks.asr import ASRTransducerTask |
| | | from funasr.tasks.lm import LMTask |
| | | from funasr.text.build_tokenizer import build_tokenizer |
| | | from funasr.text.token_id_converter import TokenIDConverter |
| | |
| | | self.streaming = streaming |
| | | self.simu_streaming = simu_streaming |
| | | self.chunk_size = max(chunk_size, 0) |
| | | self.left_context = max(left_context, 0) |
| | | self.left_context = left_context |
| | | self.right_context = max(right_context, 0) |
| | | |
| | | if not streaming or chunk_size == 0: |
| | |
| | | self.frontend = frontend |
| | | self.window_size = self.chunk_size + self.right_context |
| | | |
| | | self._ctx = self.asr_model.encoder.get_encoder_input_size( |
| | | self.window_size |
| | | ) |
| | | if self.streaming: |
| | | self._ctx = self.asr_model.encoder.get_encoder_input_size( |
| | | self.window_size |
| | | ) |
| | | |
| | | #self.last_chunk_length = ( |
| | | # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 |
| | | #) * self.hop_length |
| | | |
| | | self.last_chunk_length = ( |
| | | self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 |
| | | ) |
| | | self.reset_inference_cache() |
| | | self.last_chunk_length = ( |
| | | self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 |
| | | ) |
| | | self.reset_inference_cache() |
| | | |
| | | def reset_inference_cache(self) -> None: |
| | | """Reset Speech2Text parameters.""" |
| | |
| | | |
| | | feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) |
| | | feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) |
| | | |
| | | if self.asr_model.normalize is not None: |
| | | feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) |
| | | |
| | | feats = to_device(feats, device=self.device) |
| | | feats_lengths = to_device(feats_lengths, device=self.device) |