Funasr1.0 (#1277)

* funasr1.0 finetune
* funasr1.0 pbar
* update with main (#1260)
* Update websocket_protocol_zh.md
* update
* update with main (#1264)
* bug fix
* funasr1.0 sanm scama
* funasr1.0 infer_after_finetune
* funasr1.0 fsmn-vad bug fix
* funasr1.0 finetune

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
        self.punc_kwargs = punc_kwargs
        self.spk_model = spk_model
        self.spk_kwargs = spk_kwargs
        self.model_path = kwargs.get("model_path")

    def build_model(self, **kwargs):

        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src)
        speech, speech_lengths = extract_fbank(
            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
        )  # speech: [b, T, d]

        target = item["target"]
        if self.preprocessor_text:
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs
    ):
        is_final = kwargs.get("is_final", False)
        cache = kwargs.get("cache", {})
        if len(cache) == 0:
            self.init_cache(cache)
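Editor's note: many signatures in this diff use `cache: dict = {}`. In Python a mutable default is created once and shared across every call, which is why the rewrite above pulls `cache` out of `**kwargs` instead. A minimal self-contained illustration of the pitfall and the usual `None` idiom (names are illustrative, not FunASR code):

    # Hypothetical illustration of the mutable-default pitfall and the usual fix.
    def detect(frame, cache=None):
        # `cache=None` plus per-call init avoids sharing one dict across all
        # calls, which is what `cache: dict = {}` in a signature would do.
        if cache is None:
            cache = {}
        cache.setdefault("frames", []).append(frame)
        return cache

    print(detect(1))  # {'frames': [1]} -- a fresh cache each call
    print(detect(2))  # {'frames': [2]}, not [1, 2]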
from enum import Enum
from typing import List, Tuple, Dict, Any, Optional

from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank


class VadStateMachine(Enum):
    kVadInStateStartPointNotDetected = 1
    kVadInStateInSpeechSegment = 2
    kVadInStateEndPointDetected = 3


class FrameState(Enum):
    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0


# final voice/unvoice state per frame
class AudioChangeState(Enum):
    kChangeStateNoBegin = 4
    kChangeStateInvalid = 5


class VadDetectMode(Enum):
    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1

class VADXOptions:
    """
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        sample_rate: int = 16000,


    """
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self):
        self.start_ms = 0
        self.end_ms = 0


    """
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self):
        self.noise_prob = 0.0
        self.speech_prob = 0.0


    """
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        window_size_ms: int,
        sil_to_speech_time: int,
        speech_to_sil_time: int,

    def GetWinSize(self) -> int:
        return int(self.win_size_frame)

    def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
        cur_frame_state = FrameState.kFrameStateSil
        if frameState == FrameState.kFrameStateSpeech:
            cur_frame_state = 1
    def FrameSizeMs(self) -> int:
        return int(self.frame_size_ms)


class Stats(object):
    def __init__(
        self,
        sil_pdf_ids,
        max_end_sil_frame_cnt_thresh,
        speech_noise_thres,
    ):

        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0

        self.waveform = None
        self.last_drop_frames = 0


@tables.register("model_classes", "FsmnVADStreaming")
class FsmnVADStreaming(nn.Module):
    """
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,

        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.encoder = encoder
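Editor's note: the model is registered by name via `tables.register` and its encoder is looked up the same way. A minimal self-contained sketch of this string-to-class registry pattern (illustrative only; FunASR's actual `tables` implementation differs):

    # Minimal sketch of a string -> class registry, as used by tables.register.
    encoder_classes = {}

    def register(name):
        def decorator(cls):
            encoder_classes[name] = cls
            return cls
        return decorator

    @register("DemoEncoder")
    class DemoEncoder:
        def __init__(self, output_size: int = 256):
            self.output_size = output_size

    # Mirrors: encoder_class = tables.encoder_classes.get(encoder); encoder_class(**encoder_conf)
    encoder_class = encoder_classes.get("DemoEncoder")
    encoder = encoder_class(output_size=320)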
    def ResetDetection(self, cache: dict = {}):
        cache["stats"].continous_silence_frame_count = 0

        drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
        real_drop_frames = drop_frames - cache["stats"].last_drop_frames
        cache["stats"].last_drop_frames = drop_frames
        cache["stats"].data_buf_all = cache["stats"].data_buf_all[
            real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
        ]
        cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
        cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]

        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if cache["stats"].data_buf_all is None:
            # cache["stats"].data_buf points to cache["stats"].waveform[0]
            cache["stats"].data_buf_all = cache["stats"].waveform[0]
            cache["stats"].data_buf = cache["stats"].data_buf_all
        else:
            cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
        else:
            cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)

    def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
        while cache["stats"].data_buf_start_frame < frame_idx:
            if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                cache["stats"].data_buf_start_frame += 1
                cache["stats"].data_buf = cache["stats"].data_buf_all[
                    (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
                        self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                ]

    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                           last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
        self.PopDataBufTillFrame(start_frm, cache=cache)
        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
        if last_frm_is_end_point:

        cache["stats"].lastest_confirmed_silence_frame = valid_frame
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataBufTillFrame(valid_frame, cache=cache)
        # silence_detected_callback_
        # pass

    def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
        cache["stats"].latest_confirmed_speech_frame = valid_frame
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)

    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
        if self.vad_opts.do_start_point_detection:
            pass
        if cache["stats"].confirmed_start_frame != -1:

        if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)

    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
        for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
            self.OnVoiceDetected(t, cache=cache)
        if self.vad_opts.do_end_point_detection:

        segment_batch = []
        if len(cache["stats"].output_data_buf) > 0:
            for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
                if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point
                                     or not cache["stats"].output_data_buf[i].contain_seg_end_point):
                    continue
                segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]

        # # reset class variables and clear the dict for the next query
        # self.AllResetDetection()
        return segments
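Editor's note: in streaming mode a segment can be emitted before its end point is known. A sketch of the segment shapes typically returned here, in milliseconds (values are illustrative; by FunASR's convention -1 marks a boundary that is not available in the current chunk):

    # Illustrative shapes of streaming VAD output; exact values depend on the audio.
    segments_start_only = [[70, -1]]    # speech started at 70 ms, still ongoing
    segments_end_only = [[-1, 2340]]    # previously reported speech ended at 2340 ms
    segments_complete = [[70, 2340]]    # start and end both fall inside this chunk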
    def init_cache(self, cache: dict = {}, **kwargs):

        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        cache["encoder"] = {}

                  cache: dict = {},
                  **kwargs,
                  ):

        if len(cache) == 0:
            self.init_cache(cache, **kwargs)

        meta_data = {}
        chunk_size = kwargs.get("chunk_size", 60000)  # chunk size in ms
        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)

        time1 = time.perf_counter()

        if len(segments_i) > 0:
            segments.extend(*segments_i)

        cache["prev_samples"] = audio_sample[:-m]
        if _is_final:
            self.init_cache(cache)

        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = segments

        return results, meta_data
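Editor's note: a hedged end-to-end usage sketch of this streaming VAD through FunASR's AutoModel interface, mirroring the documented streaming example (the wav path and chunk size are illustrative):

    from funasr import AutoModel
    import soundfile

    chunk_size = 200  # ms per streaming chunk (illustrative)
    model = AutoModel(model="fsmn-vad")

    speech, sample_rate = soundfile.read("example.wav")  # hypothetical input file
    chunk_stride = int(chunk_size * sample_rate / 1000)

    cache = {}
    total_chunk_num = int(len(speech) / chunk_stride) + 1
    for i in range(total_chunk_num):
        speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
        is_final = i == total_chunk_num - 1
        res = model.generate(input=speech_chunk, cache=cache,
                             is_final=is_final, chunk_size=chunk_size)
        if len(res[0]["value"]):
            print(res[0]["value"])  # e.g. [[beg_ms, -1]] or [[-1, end_ms]]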
    def DetectCommonFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
                                             cache=cache)
            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)

        return 0

            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
                                             cache=cache)
            if i != 0:
                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
            else:

        return 0

    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
                       cache: dict = {}) -> None:
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:

            cache["stats"].pre_end_silence_detected = False
            start_frame = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                start_frame = max(cache["stats"].data_buf_start_frame,
                                  cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
                self.OnVoiceStart(start_frame, cache=cache)
                cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                for t in range(start_frame + 1, cur_frm_idx + 1):

            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # silence timeout, return zero length decision
                if (((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value)
                     and (cache["stats"].continous_silence_frame_count * frm_shift_in_ms
                          > self.vad_opts.max_start_silence_time))
                        or (is_final_frame and cache["stats"].number_end_time_detected == 0)):
                    for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
                        self.OnSilenceDetected(t, cache=cache)

                if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
                    self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if (cache["stats"].continous_silence_frame_count * frm_shift_in_ms
                        >= cache["stats"].max_end_sil_frame_cnt_thresh):
                    lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                    if self.vad_opts.do_extend:
                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)

        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
            self.ResetDetection(cache=cache)
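Editor's note: taken together, DetectOneFrame drives a three-state machine. A compact, self-contained sketch of the transitions it implements (simplified; the real logic also involves window smoothing, latency frames, and lookback):

    # Simplified view of the VAD state transitions driven by DetectOneFrame.
    from enum import Enum

    class S(Enum):
        START_NOT_DETECTED = 1
        IN_SPEECH = 2
        END_DETECTED = 3

    def step(state, frame_is_speech, silence_ms, max_end_sil_ms):
        if state == S.START_NOT_DETECTED and frame_is_speech:
            return S.IN_SPEECH       # OnVoiceStart
        if state == S.IN_SPEECH and not frame_is_speech and silence_ms >= max_end_sil_ms:
            return S.END_DETECTED    # OnVoiceEnd
        return state                 # otherwise stay; multi-utterance mode resets afterwards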


import torch

from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
from funasr.models.transformer.scorers.scorer_interface import ScorerInterface

        else:
            remained_hyps.append(hyp)
        return remained_hyps

class BeamSearchScamaStreaming(torch.nn.Module):
    """Beam search implementation."""

    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorer
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The size of the vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`

        """
        super().__init__()
        # set scorers
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v

        # set configurations
        self.sos = sos
        self.eos = eos
        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )

    def init_hyp(self, x) -> List[Hypothesis]:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.

        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append

        Returns:
            torch.Tensor: New tensor containing: xs + [x] with xs.dtype and xs.device

        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))

    def score_full(
        self, hyp: Hypothesis,
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        cache: dict = {},
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask,
                                           pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
        return scores, states

    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states

    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each token.
                Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`

        """
        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids

        # mask pruned in pre-beam not to select in topk
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids
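Editor's note: a small, self-contained illustration of the pre-beam masking trick used in beam() above (toy scores, beam size 2; values are illustrative):

    import torch

    # Toy weighted scores over a 6-token vocabulary.
    weighted_scores = torch.tensor([0.1, 0.9, 0.3, 0.8, 0.2, 0.7])
    ids = torch.tensor([1, 3, 5])          # tokens kept by the pre-beam
    beam_size = 2

    tmp = weighted_scores[ids]             # save pre-beam scores
    weighted_scores[:] = -float("inf")     # mask everything else
    weighted_scores[ids] = tmp
    top_ids = weighted_scores.topk(beam_size)[1]         # global token ids -> tensor([1, 3])
    local_ids = weighted_scores[ids].topk(beam_size)[1]  # positions within `ids` -> tensor([0, 1])
    print(top_ids, local_ids)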
    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.

        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are scalar tensors by the scorers.

        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states

    def search(
        self, running_hyps: List[Hypothesis],
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        cache: dict = {},
    ) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            List[Hypothesis]: Best sorted hypotheses

        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            scores, states = self.score_full(hyp, x, x_mask=x_mask,
                                             pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score

            # update hyps
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )

        # sort and prune 2 x beam -> beam
        best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
            : min(len(best_hyps), self.beam_size)
        ]
        return best_hyps
    def forward(
        self, x: torch.Tensor,
        scama_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        maxlen: int = None,
        minlen: int = 0,
        cache: dict = {},
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses an end-detect function
                to automatically find maximum hypothesis lengths.
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        if maxlen is None:
            # set length bounds
            if maxlenratio == 0:
                maxlen = x.shape[0]
            elif maxlenratio < 0:
                maxlen = -1 * int(maxlenratio)
            else:
                maxlen = max(1, int(maxlenratio * x.size(0)))
            minlen = int(minlenratio * x.size(0))

        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        # running_hyps = self.init_hyp(x)
        running_hyps = cache["running_hyps"]
        ended_hyps = []
        for i in range(maxlen):
            logging.debug("position " + str(i))
            mask_enc = None
            # if scama_mask is not None:
            #     token_num_predictor = scama_mask.size(1)
            #     token_id_slice = min(i, token_num_predictor - 1)
            #     mask_enc = scama_mask[:, token_id_slice:token_id_slice + 1, :]
            #     # if mask_enc.size(1) == 0:
            #     #     mask_enc = scama_mask[:, -2:-1, :]
            #     #     # mask_enc = torch.zeros_like(mask_enc)
            pre_acoustic_embeds_cur = None
            if pre_acoustic_embeds is not None:
                b, t, d = pre_acoustic_embeds.size()
                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
                token_id_slice = min(i, t)
                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice + 1, :]

            best = self.search(running_hyps, x, x_mask=mask_enc,
                               pre_acoustic_embeds=pre_acoustic_embeds_cur, cache=cache["decoder"])
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                logging.info(f"end detected at {i}")
                break
            if len(running_hyps) == 0:
                logging.info("no hypothesis. Finish decoding.")
                break
            else:
                logging.debug(f"remained hypotheses: {len(running_hyps)}")

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there are no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio=maxlenratio,
                                  minlenratio=max(0.0, minlenratio - 0.1))
            )

        # report the best result
        for x in nbest_hyps:
            yseq = "".join([self.token_list[x] for x in x.yseq])
            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            List[Hypothesis]: The new running hypotheses.

        """
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
            ]

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps
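Editor's note: a hedged sketch of wiring the class up standalone (the scorer set, token list, and sizes are toy values; the real wiring lives in init_beam_search below). LengthBonus is an existing FunASR scorer:

    from funasr.models.transformer.scorers.length_bonus import LengthBonus

    token_list = ["<blank>", "a", "b", "<sos/eos>"]
    beam_search = BeamSearchScamaStreaming(
        scorers={"length_bonus": LengthBonus(len(token_list))},
        weights={"length_bonus": 0.1, "decoder": 1.0},
        beam_size=2,
        vocab_size=len(token_list),
        sos=3,
        eos=3,
        token_list=token_list,
    )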
    def init_beam_search(self,
                         **kwargs,
                         ):
        from funasr.models.scama.beam_search import BeamSearchScamaStreaming
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        scorers["ngram"] = ngram

        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearchScamaStreaming(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
                                          is_final=kwargs.get("is_final", False))
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        if "running_hyps" not in cache:
            running_hyps = self.beam_search.init_hyp(encoder_out)
            cache["running_hyps"] = running_hyps

        # predictor
        predictor_outs = self.calc_predictor_chunk(encoder_out,
                                                   encoder_out_lens,

        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
                                                             encoder_out_lens,
                                                             pre_acoustic_embeds,
                                                             pre_token_length,
                                                             cache=cache
                                                             )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        maxlen = minlen = pre_token_length
        if kwargs.get("is_final", False):
            maxlen += kwargs.get("token_num_relax", 5)
            minlen = max(0, minlen - kwargs.get("token_num_relax", 5))
        # c. pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0], scama_mask=None, pre_acoustic_embeds=pre_acoustic_embeds,
            maxlen=int(maxlen), minlen=int(minlen), cache=cache,
        )

        cache["running_hyps"] = nbest_hyps
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0)
                )

                nbest_hyps = nbest_hyps[: self.nbest]
                for hyp in nbest_hyps:
                    # assert isinstance(hyp, (Hypothesis)), type(hyp)

                    # remove sos/eos and get results
                    last_pos = -1
                    if isinstance(hyp.yseq, list):
                        token_int = hyp.yseq[1:last_pos]
                    else:

                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor(
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

        return results

    def init_cache(self, cache: dict = {}, **kwargs):
        device = kwargs.get("device", "cuda")

        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
        cache_encoder = {"start_idx": 0,
                         "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
                         "cif_alphas": torch.zeros((batch_size, 1)).to(device=device),
                         "chunk_size": chunk_size,
                         "encoder_chunk_look_back": encoder_chunk_look_back,
                         "last_chunk": False,
                         "opt": None,
                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device),
                         "tail_chunk": False}
        cache["encoder"] = cache_encoder

                         "chunk_size": chunk_size}
        cache["decoder"] = cache_decoder
        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0).to(device=device)

        return cache
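Editor's note: for reference, a sketch of the streaming cache built above (field meanings are inferred from the code, not authoritative; placeholders stand in for tensors):

    # Illustrative layout of the cache returned by init_cache().
    cache = {
        "encoder": {
            "start_idx": 0,            # frame offset of the current chunk
            "cif_hidden": None,        # (batch, 1, enc_output_size) running CIF state
            "cif_alphas": None,        # (batch, 1) accumulated CIF weights
            "chunk_size": [0, 10, 5],  # streaming chunk configuration
            "feats": None,             # (batch, chunk_size[0] + chunk_size[2], feats_dims) feature tail
            "tail_chunk": False,
        },
        "decoder": {},                 # per-layer decoder states plus chunk_size
        "frontend": {},                # fbank/LFR state carried between chunks
        "prev_samples": None,          # raw samples left over from the previous call
    }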
    def inference(self,

        # init beam search
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
            self._train_epoch(epoch)

            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            if self.rank == 0:
                self._save_checkpoint(epoch)

            if self.use_ddp or self.use_fsdp:
                dist.barrier()

        if self.writer:
            self.writer.close()
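Editor's note: the barrier/rank-0 pattern above is the standard way to checkpoint under DDP or FSDP. A minimal self-contained sketch (assuming torch.distributed is already initialized; function name is illustrative):

    import torch
    import torch.distributed as dist

    def save_checkpoint_rank0(model, path, rank):
        # No rank should run ahead while rank 0 is the only writer.
        dist.barrier()
        if rank == 0:
            torch.save(model.state_dict(), path)
        # Block everyone until the checkpoint exists on disk.
        dist.barrier()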