Files changed (7):

  funasr/auto/auto_model.py
  funasr/datasets/audio_datasets/datasets.py
  funasr/frontends/wav_frontend.py
  funasr/models/fsmn_vad_streaming/model.py
  funasr/models/scama/beam_search.py
  funasr/models/scama/model.py
  funasr/train_utils/trainer.py
funasr/auto/auto_model.py
@@ -132,7 +132,8 @@
         self.punc_kwargs = punc_kwargs
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
-        self.model_path = kwargs.get("model_path", "./")
+        self.model_path = kwargs.get("model_path")

     def build_model(self, **kwargs):
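The behavioral consequence of dropping the "./" default (plain dict.get semantics, shown for clarity): a missing "model_path" key now yields None instead of silently pointing at the current directory, so downstream code can distinguish "not set" from "set to ./".

    kwargs = {}
    kwargs.get("model_path", "./")   # "./"  (old behavior)
    kwargs.get("model_path")         # None  (new behavior)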
funasr/datasets/audio_datasets/datasets.py
@@ -58,7 +58,7 @@
         data_src = load_audio_text_image_video(source, fs=self.fs)
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src)
-        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend)  # speech: [b, T, d]
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True)  # speech: [b, T, d]
         target = item["target"]
         if self.preprocessor_text:
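The new is_final=True matters when self.frontend is a streaming frontend that buffers look-ahead frames between calls; in the offline dataset each utterance is complete, so the flag flushes any buffered frames in one shot. A sketch of the chunked counterpart, per the wav_frontend.py forward signature below (the chunk tensors and lengths are illustrative, not this dataset's code):

    cache = {}
    feats1, len1 = frontend(chunk1, chunk1_len, cache=cache, is_final=False)  # may hold back look-ahead frames
    feats2, len2 = frontend(chunk2, chunk2_len, cache=cache, is_final=True)   # final call flushes the cache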
funasr/frontends/wav_frontend.py
@@ -399,9 +399,10 @@
         return feats_pad, feats_lens, lfr_splice_frame_idxs

     def forward(
-        self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs
+        self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs
     ):
         is_final = kwargs.get("is_final", False)
+        cache = kwargs.get("cache", {})
         if len(cache) == 0:
             self.init_cache(cache)
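Moving cache out of the positional defaults also sidesteps Python's mutable-default pitfall: a dict default is created once at definition time and shared by every call that omits the argument. A minimal standalone demonstration:

    def f(cache: dict = {}):
        cache["n"] = cache.get("n", 0) + 1
        return cache["n"]

    f()  # 1
    f()  # 2 -- state leaks between unrelated calls

With cache = kwargs.get("cache", {}), a fresh dict is built on each call that does not pass one explicitly.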
funasr/models/fsmn_vad_streaming/model.py
@@ -23,10 +23,12 @@
     kVadInStateInSpeechSegment = 2
     kVadInStateEndPointDetected = 3

+
 class FrameState(Enum):
     kFrameStateInvalid = -1
     kFrameStateSpeech = 1
     kFrameStateSil = 0

+
 # final voice/unvoice state per frame
 class AudioChangeState(Enum):
@@ -37,9 +39,11 @@
     kChangeStateNoBegin = 4
     kChangeStateInvalid = 5

+
 class VadDetectMode(Enum):
     kVadSingleUtteranceDetectMode = 0
     kVadMutipleUtteranceDetectMode = 1

+
 class VADXOptions:
     """
@@ -47,6 +51,7 @@
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
+
    def __init__(
            self,
            sample_rate: int = 16000,
@@ -117,6 +122,7 @@
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
+
    def __init__(self):
        self.start_ms = 0
        self.end_ms = 0
@@ -140,6 +146,7 @@
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
+
    def __init__(self):
        self.noise_prob = 0.0
        self.speech_prob = 0.0
@@ -154,6 +161,7 @@
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
+
    def __init__(self, window_size_ms: int,
                 sil_to_speech_time: int,
                 speech_to_sil_time: int,
@@ -220,13 +228,13 @@
     def FrameSizeMs(self) -> int:
         return int(self.frame_size_ms)


 class Stats(object):
     def __init__(self,
                  sil_pdf_ids,
                  max_end_sil_frame_cnt_thresh,
                  speech_noise_thres,
                  ):
         self.data_buf_start_frame = 0
         self.frm_cnt = 0
         self.latest_confirmed_speech_frame = 0
@@ -255,6 +263,7 @@
         self.waveform = None
         self.last_drop_frames = 0

+
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
     """
@@ -262,6 +271,7 @@
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
+
    def __init__(self,
                 encoder: str = None,
                 encoder_conf: Optional[Dict] = None,
@@ -274,7 +284,6 @@
         encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
         self.encoder = encoder
-
     def ResetDetection(self, cache: dict = {}):
         cache["stats"].continous_silence_frame_count = 0
@@ -292,7 +301,8 @@
             drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
             real_drop_frames = drop_frames - cache["stats"].last_drop_frames
             cache["stats"].last_drop_frames = drop_frames
-            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
+                self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
             cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
             cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
@@ -300,7 +310,8 @@
         frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
         frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
         if cache["stats"].data_buf_all is None:
-            cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
+            cache["stats"].data_buf_all = cache["stats"].waveform[
+                0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
             cache["stats"].data_buf = cache["stats"].data_buf_all
         else:
             cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
@@ -323,7 +334,8 @@
         while cache["stats"].data_buf_start_frame < frame_idx:
             if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                 cache["stats"].data_buf_start_frame += 1
-                cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
-                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+                cache["stats"].data_buf = cache["stats"].data_buf_all[
+                    (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
+                        self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]

     def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
@@ -379,6 +391,7 @@
         cache["stats"].lastest_confirmed_silence_frame = valid_frame
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
             self.PopDataBufTillFrame(valid_frame, cache=cache)
         # silence_detected_callback_
         # pass
+
@@ -487,7 +500,8 @@
         segment_batch = []
         if len(cache["stats"].output_data_buf) > 0:
             for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
-                if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[i].contain_seg_end_point):
+                if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
+                    i].contain_seg_end_point):
                     continue
                 segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
@@ -580,7 +594,6 @@
             if len(segments_i) > 0:
                 segments.extend(*segments_i)

         cache["prev_samples"] = audio_sample[:-m]
-
         if _is_final:
             self.init_cache(cache)
@@ -600,16 +613,15 @@
                 if ibest_writer is not None:
                     ibest_writer["text"][key[0]] = segments

         return results, meta_data

     def DetectCommonFrames(self, cache: dict = {}) -> int:
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
             return 0
         for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
             frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
-                                             cache=cache)
+            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
             self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)

         return 0
@@ -619,7 +631,8 @@
             return 0
         for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
             frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
+                                             cache=cache)
             if i != 0:
                 self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
             else:
@@ -627,7 +640,8 @@
         return 0

-    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
+    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
+                       cache: dict = {}) -> None:
         tmp_cur_frm_state = FrameState.kFrameStateInvalid
         if cur_frm_state == FrameState.kFrameStateSpeech:
             if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -644,7 +658,8 @@
             cache["stats"].pre_end_silence_detected = False
             start_frame = 0
             if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-                start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
+                start_frame = max(cache["stats"].data_buf_start_frame,
+                                  cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
                 self.OnVoiceStart(start_frame, cache=cache)
                 cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                 for t in range(start_frame + 1, cur_frm_idx + 1):
@@ -696,7 +711,8 @@
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
             # silence timeout, return zero length decision
             if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
-                    cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+                    cache[
+                        "stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
                     or (is_final_frame and cache["stats"].number_end_time_detected == 0):
                 for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
                     self.OnSilenceDetected(t, cache=cache)
@@ -707,7 +723,8 @@
             if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
                 self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
         elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-            if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
+            if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
+                "stats"].max_end_sil_frame_cnt_thresh:
                 lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                 if self.vad_opts.do_extend:
                     lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
@@ -731,6 +748,5 @@
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                 self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
             self.ResetDetection(cache=cache)
-
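Most hunks above are cosmetic (PEP 8 blank lines and line re-wraps); the frames-to-samples arithmetic they re-wrap recurs throughout the file, so a worked instance may help. A minimal sketch with the usual defaults (frame_in_ms=10 and sample_rate=16000 are assumed VADXOptions defaults):

    frame_in_ms = 10                 # vad_opts.frame_in_ms
    sample_rate = 16000              # vad_opts.sample_rate
    samples_per_frame_shift = int(frame_in_ms * sample_rate / 1000)  # 160 samples
    # dropping N confirmed frames from the waveform buffer therefore strips
    # N * 160 samples, which is what ResetDetection's slicing above computes.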
funasr/models/scama/beam_search.py
@@ -11,7 +11,7 @@
 import torch

-from funasr.metrics import end_detect
+from funasr.metrics.common import end_detect
 from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
 from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
@@ -494,3 +494,468 @@
             else:
                 remained_hyps.append(hyp)
         return remained_hyps
+
+
+class BeamSearchScamaStreaming(torch.nn.Module):
+    """Beam search implementation."""
+
+    def __init__(
+        self,
+        scorers: Dict[str, ScorerInterface],
+        weights: Dict[str, float],
+        beam_size: int,
+        vocab_size: int,
+        sos: int,
+        eos: int,
+        token_list: List[str] = None,
+        pre_beam_ratio: float = 1.5,
+        pre_beam_score_key: str = None,
+    ):
+        """Initialize beam search.
+
+        Args:
+            scorers (dict[str, ScorerInterface]): Dict of decoder modules
+                e.g., Decoder, CTCPrefixScorer, LM.
+                The scorer will be ignored if it is `None`.
+            weights (dict[str, float]): Dict of weights for each scorer.
+                The scorer will be ignored if its weight is 0.
+            beam_size (int): The number of hypotheses kept during search.
+            vocab_size (int): The size of the vocabulary.
+            sos (int): Start of sequence id.
+            eos (int): End of sequence id.
+            token_list (list[str]): List of tokens for debug log.
+            pre_beam_score_key (str): key of scores to perform pre-beam search.
+            pre_beam_ratio (float): beam size in the pre-beam search
+                will be `int(pre_beam_ratio * beam_size)`.
+
+        """
+        super().__init__()
+        # set scorers
+        self.weights = weights
+        self.scorers = dict()
+        self.full_scorers = dict()
+        self.part_scorers = dict()
+        # this module dict is required for recursive cast
+        # `self.to(device, dtype)` in `recog.py`
+        self.nn_dict = torch.nn.ModuleDict()
+        for k, v in scorers.items():
+            w = weights.get(k, 0)
+            if w == 0 or v is None:
+                continue
+            assert isinstance(
+                v, ScorerInterface
+            ), f"{k} ({type(v)}) does not implement ScorerInterface"
+            self.scorers[k] = v
+            if isinstance(v, PartialScorerInterface):
+                self.part_scorers[k] = v
+            else:
+                self.full_scorers[k] = v
+            if isinstance(v, torch.nn.Module):
+                self.nn_dict[k] = v
+
+        # set configurations
+        self.sos = sos
+        self.eos = eos
+        self.token_list = token_list
+        self.pre_beam_size = int(pre_beam_ratio * beam_size)
+        self.beam_size = beam_size
+        self.n_vocab = vocab_size
+        if (
+            pre_beam_score_key is not None
+            and pre_beam_score_key != "full"
+            and pre_beam_score_key not in self.full_scorers
+        ):
+            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+        self.pre_beam_score_key = pre_beam_score_key
+        self.do_pre_beam = (
+            self.pre_beam_score_key is not None
+            and self.pre_beam_size < self.n_vocab
+            and len(self.part_scorers) > 0
+        )
+
+    def init_hyp(self, x) -> List[Hypothesis]:
+        """Get an initial hypothesis data.
+
+        Args:
+            x (torch.Tensor): The encoder output feature
+
+        Returns:
+            Hypothesis: The initial hypothesis.
+
+        """
+        init_states = dict()
+        init_scores = dict()
+        for k, d in self.scorers.items():
+            init_states[k] = d.init_state(x)
+            init_scores[k] = 0.0
+        return [
+            Hypothesis(
+                score=0.0,
+                scores=init_scores,
+                states=init_states,
+                yseq=torch.tensor([self.sos], device=x.device),
+            )
+        ]
+
+    @staticmethod
+    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+        """Append new token to prefix tokens.
+
+        Args:
+            xs (torch.Tensor): The prefix token
+            x (int): The new token to append
+
+        Returns:
+            torch.Tensor: New tensor containing: xs + [x] with xs.dtype and xs.device
+
+        """
+        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+        return torch.cat((xs, x))
+
+    def score_full(
+        self,
+        hyp: Hypothesis,
+        x: torch.Tensor,
+        x_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+        cache: dict = {},
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.full_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.full_scorers`
+                and tensor score values of shape: `(self.n_vocab,)`,
+                and state dict that has string keys
+                and state values of `self.full_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.full_scorers.items():
+            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x,
+                                           x_mask=x_mask,
+                                           pre_acoustic_embeds=pre_acoustic_embeds,
+                                           cache=cache)
+        return scores, states
+
+    def score_partial(
+        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.part_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            ids (torch.Tensor): 1D tensor of new partial tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.part_scorers`
+                and tensor score values of shape: `(len(ids),)`,
+                and state dict that has string keys
+                and state values of `self.part_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.part_scorers.items():
+            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+        return scores, states
+
+    def beam(
+        self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute topk full token ids and partial token ids.
+
+        Args:
+            weighted_scores (torch.Tensor): The weighted sum scores for each token.
+                Its shape is `(self.n_vocab,)`.
+            ids (torch.Tensor): The partial token ids to compute topk
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                The topk full token ids and partial token ids.
+                Their shapes are `(self.beam_size,)`
+
+        """
+        # no pre beam performed
+        if weighted_scores.size(0) == ids.size(0):
+            top_ids = weighted_scores.topk(self.beam_size)[1]
+            return top_ids, top_ids
+        # mask pruned in pre-beam not to select in topk
+        tmp = weighted_scores[ids]
+        weighted_scores[:] = -float("inf")
+        weighted_scores[ids] = tmp
+        top_ids = weighted_scores.topk(self.beam_size)[1]
+        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+        return top_ids, local_ids
+
+    @staticmethod
+    def merge_scores(
+        prev_scores: Dict[str, float],
+        next_full_scores: Dict[str, torch.Tensor],
+        full_idx: int,
+        next_part_scores: Dict[str, torch.Tensor],
+        part_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Merge scores for new hypothesis.
+
+        Args:
+            prev_scores (Dict[str, float]):
+                The previous hypothesis scores by `self.scorers`
+            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+            full_idx (int): The next token id for `next_full_scores`
+            next_part_scores (Dict[str, torch.Tensor]):
+                scores of partial tokens by `self.part_scorers`
+            part_idx (int): The new token id for `next_part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are scalar tensors by the scorers.
+
+        """
+        new_scores = dict()
+        for k, v in next_full_scores.items():
+            new_scores[k] = prev_scores[k] + v[full_idx]
+        for k, v in next_part_scores.items():
+            new_scores[k] = prev_scores[k] + v[part_idx]
+        return new_scores
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Merge states for new hypothesis.
+
+        Args:
+            states: states of `self.full_scorers`
+            part_states: states of `self.part_scorers`
+            part_idx (int): The new token id for `part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new state dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are states of the scorers.
+
+        """
+        new_states = dict()
+        for k, v in states.items():
+            new_states[k] = v
+        for k, d in self.part_scorers.items():
+            new_states[k] = d.select_state(part_states[k], part_idx)
+        return new_states
+
+    def search(
+        self,
+        running_hyps: List[Hypothesis],
+        x: torch.Tensor,
+        x_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+        cache: dict = {},
+    ) -> List[Hypothesis]:
+        """Search new tokens for running hypotheses and encoded speech x.
+
+        Args:
+            running_hyps (List[Hypothesis]): Running hypotheses on beam
+            x (torch.Tensor): Encoded speech feature (T, D)
+
+        Returns:
+            List[Hypothesis]: Best sorted hypotheses
+
+        """
+        best_hyps = []
+        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
+        for hyp in running_hyps:
+            # scoring
+            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+            scores, states = self.score_full(hyp, x,
+                                             x_mask=x_mask,
+                                             pre_acoustic_embeds=pre_acoustic_embeds,
+                                             cache=cache)
+            for k in self.full_scorers:
+                weighted_scores += self.weights[k] * scores[k]
+            # partial scoring
+            if self.do_pre_beam:
+                pre_beam_scores = (
+                    weighted_scores
+                    if self.pre_beam_score_key == "full"
+                    else scores[self.pre_beam_score_key]
+                )
+                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+            part_scores, part_states = self.score_partial(hyp, part_ids, x)
+            for k in self.part_scorers:
+                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+            # add previous hyp score
+            weighted_scores += hyp.score
+
+            # update hyps
+            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+                # will be (2 x beam at most)
+                best_hyps.append(
+                    Hypothesis(
+                        score=weighted_scores[j],
+                        yseq=self.append_token(hyp.yseq, j),
+                        scores=self.merge_scores(
+                            hyp.scores, scores, j, part_scores, part_j
+                        ),
+                        states=self.merge_states(states, part_states, part_j),
+                    )
+                )
+
+            # sort and prune 2 x beam -> beam
+            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+                : min(len(best_hyps), self.beam_size)
+            ]
+        return best_hyps
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        scama_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+        maxlenratio: float = 0.0,
+        minlenratio: float = 0.0,
+        maxlen: int = None,
+        minlen: int = 0,
+        cache: dict = {},
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), it uses an end-detect function
+                to automatically find maximum hypothesis lengths.
+                If maxlenratio<0.0, its absolute value is interpreted
+                as a constant max output length.
+            minlenratio (float): Input length ratio to obtain min output length.
+
+        Returns:
+            list[Hypothesis]: N-best decoding results
+
+        """
+        if maxlen is None:
+            # set length bounds
+            if maxlenratio == 0:
+                maxlen = x.shape[0]
+            elif maxlenratio < 0:
+                maxlen = -1 * int(maxlenratio)
+            else:
+                maxlen = max(1, int(maxlenratio * x.size(0)))
+            minlen = int(minlenratio * x.size(0))
+            logging.info("decoder input length: " + str(x.shape[0]))
+            logging.info("max output length: " + str(maxlen))
+            logging.info("min output length: " + str(minlen))
+
+        # main loop of prefix search
+        # running_hyps = self.init_hyp(x)
+        running_hyps = cache["running_hyps"]
+        ended_hyps = []
+        for i in range(maxlen):
+            logging.debug("position " + str(i))
+            mask_enc = None
+            # if scama_mask is not None:
+            #     token_num_predictor = scama_mask.size(1)
+            #     token_id_slice = min(i, token_num_predictor-1)
+            #     mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
+            #     # if mask_enc.size(1) == 0:
+            #     #     mask_enc = scama_mask[:, -2:-1, :]
+            #     #     # mask_enc = torch.zeros_like(mask_enc)
+            pre_acoustic_embeds_cur = None
+            if pre_acoustic_embeds is not None:
+                b, t, d = pre_acoustic_embeds.size()
+                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
+                token_id_slice = min(i, t)
+                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice + 1, :]
+            best = self.search(running_hyps, x,
+                               x_mask=mask_enc,
+                               pre_acoustic_embeds=pre_acoustic_embeds_cur,
+                               cache=cache["decoder"])
+            # post process of one iteration
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection
+            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.info(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.info("no hypothesis. Finish decoding.")
+                break
+            else:
+                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+        # check the number of hypotheses reaching to eos
+        if len(nbest_hyps) == 0:
+            logging.warning(
+                "there is no N-best result, perform recognition "
+                "again with smaller minlenratio."
+            )
+            return (
+                []
+                if minlenratio < 0.1
+                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+            )
+
+        # report the best result
+        for x in nbest_hyps:
+            yseq = "".join([self.token_list[x] for x in x.yseq])
+            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
+        best = nbest_hyps[0]
+        for k, v in best.scores.items():
+            logging.info(
+                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+            )
+        logging.info(f"total log probability: {best.score:.2f}")
+        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            logging.info(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "\n"
+            )
+        return nbest_hyps
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: List[Hypothesis],
+        ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Perform post-processing of beam search iterations.
+
+        Args:
+            i (int): The length of hypothesis tokens.
+            maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (int): The maximum length ratio in beam search.
+            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+        Returns:
+            List[Hypothesis]: The new running hypotheses.
+
+        """
+        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+            )
+        # add eos in the final loop to avoid that there are no ended hyps
+        if i == maxlen - 1:
+            logging.info("adding <eos> in the last position in the loop")
+            running_hyps = [
+                h._replace(yseq=self.append_token(h.yseq, self.eos))
+                for h in running_hyps
+            ]
+        # add ended hypotheses to a final list, and remove them from current hypotheses
+        # (this will be a problem, number of hyps < beam)
+        remained_hyps = []
+        for hyp in running_hyps:
+            if hyp.yseq[-1] == self.eos:
+                # e.g., Word LM needs to add final <eos> score
+                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+                    s = d.final_score(hyp.states[k])
+                    hyp.scores[k] += s
+                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+                ended_hyps.append(hyp)
+            else:
+                remained_hyps.append(hyp)
+        return remained_hyps
funasr/models/scama/model.py
@@ -436,7 +436,10 @@
     def init_beam_search(self,
                          **kwargs,
                          ):
-        from funasr.models.scama.beam_search import BeamSearchScama
+        from funasr.models.scama.beam_search import BeamSearchScamaStreaming
         from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
         from funasr.models.transformer.scorers.length_bonus import LengthBonus
@@ -460,13 +463,14 @@
             scorers["ngram"] = ngram

         weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
             ctc=kwargs.get("decoding_ctc_weight", 0.0),
             lm=kwargs.get("lm_weight", 0.0),
             ngram=kwargs.get("ngram_weight", 0.0),
             length_bonus=kwargs.get("penalty", 0.0),
         )
-        beam_search = BeamSearchScama(
+        beam_search = BeamSearchScamaStreaming(
             beam_size=kwargs.get("beam_size", 2),
             weights=weights,
             scorers=scorers,
@@ -499,6 +503,10 @@
                                                  is_final=kwargs.get("is_final", False))
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]

+        if "running_hyps" not in cache:
+            running_hyps = self.beam_search.init_hyp(encoder_out)
+            cache["running_hyps"] = running_hyps
+
         # predictor
         predictor_outs = self.calc_predictor_chunk(encoder_out,
@@ -513,39 +521,21 @@
         if torch.max(pre_token_length) < 1:
             return []
-        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length,
-                                                             cache=cache)
-        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+        maxlen = minlen = pre_token_length
+        if kwargs.get("is_final", False):
+            maxlen += kwargs.get("token_num_relax", 5)
+            minlen = max(0, minlen - kwargs.get("token_num_relax", 5))
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0],
+            scama_mask=None,
+            pre_acoustic_embeds=pre_acoustic_embeds,
+            maxlen=int(maxlen),
+            minlen=int(minlen),
+            cache=cache,
+        )
+        cache["running_hyps"] = nbest_hyps
+        nbest_hyps = nbest_hyps[: self.nbest]

         results = []
-        b, n, d = decoder_out.size()
         if isinstance(key[0], (list, tuple)):
             key = key[0]
-        for i in range(b):
-            x = encoder_out[i, :encoder_out_lens[i], :]
-            am_scores = decoder_out[i, :pre_token_length[i], :]
-            if self.beam_search is not None:
-                nbest_hyps = self.beam_search(
-                    x=x, am_scores=am_scores,
-                    maxlenratio=kwargs.get("maxlenratio", 0.0),
-                    minlenratio=kwargs.get("minlenratio", 0.0)
-                )
-                nbest_hyps = nbest_hyps[: self.nbest]
-            else:
-                yseq = am_scores.argmax(dim=-1)
-                score = am_scores.max(dim=-1)[0]
-                score = torch.sum(score, dim=-1)
-                # pad with mask tokens to ensure compatibility with sos/eos tokens
-                yseq = torch.tensor(
-                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
-                )
-                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-            for nbest_idx, hyp in enumerate(nbest_hyps):
+        for hyp in nbest_hyps:
             # assert isinstance(hyp, (Hypothesis)), type(hyp)
             # remove sos/eos and get results
             last_pos = -1
@@ -553,6 +543,7 @@
                 token_int = hyp.yseq[1:last_pos]
             else:
                 token_int = hyp.yseq[1:last_pos].tolist()
+
             # remove blank symbol id, which is assumed to be 0
             token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
@@ -568,6 +559,8 @@
         return results

     def init_cache(self, cache: dict = {}, **kwargs):
+        device = kwargs.get("device", "cuda")
+
         chunk_size = kwargs.get("chunk_size", [0, 10, 5])
         encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
         decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
@@ -575,10 +568,11 @@
         enc_output_size = kwargs["encoder_conf"]["output_size"]
         feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]

-        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)), "cif_alphas":
-            torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+        cache_encoder = {"start_idx": 0,
+                         "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
+                         "cif_alphas": torch.zeros((batch_size, 1)).to(device=device),
+                         "chunk_size": chunk_size,
                          "encoder_chunk_look_back": encoder_chunk_look_back,
                          "last_chunk": False,
                          "opt": None,
-                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device),
                          "tail_chunk": False}
         cache["encoder"] = cache_encoder
@@ -586,7 +580,9 @@
                          "chunk_size": chunk_size}
         cache["decoder"] = cache_decoder
         cache["frontend"] = {}
-        cache["prev_samples"] = torch.empty(0)
+        cache["prev_samples"] = torch.empty(0).to(device=device)
+
+        return cache
@@ -603,7 +599,10 @@
         # init beamsearch
         is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
         is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
+        if self.beam_search is None:
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
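Net effect of the init_cache changes: every tensor the streaming cache pre-allocates now lands on the requested device up front, so the first chunk does not trigger implicit CPU-to-GPU copies, and the method now returns the cache. A minimal sketch (kwargs trimmed to the fields read above; values are illustrative and remaining fields are assumed to take their defaults):

    cache = model.init_cache(
        {},
        device="cuda",
        chunk_size=[0, 10, 5],
        encoder_conf={"output_size": 512},
        frontend_conf={"n_mels": 80, "lfr_m": 7},
    )
    assert cache["encoder"]["cif_hidden"].device.type == "cuda"
    assert cache["prev_samples"].device.type == "cuda"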
funasr/train_utils/trainer.py
@@ -149,6 +149,7 @@
             self._train_epoch(epoch)

             if self.use_ddp or self.use_fsdp:
+                dist.barrier()
@@ -173,6 +174,7 @@
         if self.use_ddp or self.use_fsdp:
+            dist.barrier()

         if self.writer:
             self.writer.close()
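The added barriers make every rank wait at the end of an epoch (and before the summary writer is closed), so rank 0 cannot run ahead into validation or checkpointing while other ranks are still training. The pattern in isolation (assumes a torch.distributed process group is already initialized, e.g. by torchrun; the epoch loop is hypothetical):

    import torch.distributed as dist

    for epoch in range(start_epoch, max_epoch):
        train_one_epoch(epoch)      # per-rank work for one epoch
        if dist.is_initialized():
            dist.barrier()          # all ranks sync before anyone proceeds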