凌匀
2023-02-16 d6fdd1c79364668afbf171bd18586dd1d7570d20
funasr/models/e2e_vad.py
@@ -5,7 +5,6 @@
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
# from checkpoint import load_checkpoint
class VadStateMachine(Enum):
@@ -136,7 +135,7 @@
        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.win_sum = 0
        self.win_state = [0 for i in range(0, self.win_size_frame)]  # 初始化窗
        self.win_state = [0] * self.win_size_frame  # 初始化窗
        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
@@ -151,7 +150,7 @@
    def Reset(self) -> None:
        self.cur_win_pos = 0
        self.win_sum = 0
        self.win_state = [0 for i in range(0, self.win_size_frame)]
        self.win_state = [0] * self.win_size_frame
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.voice_last_frame_count = 0
@@ -192,8 +191,8 @@
        return int(self.frame_size_ms)
class E2EVadModel(torch.nn.Module):
    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
class E2EVadModel(nn.Module):
    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False):
        super(E2EVadModel, self).__init__()
        self.vad_opts = VADXOptions(**vad_post_args)
        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -212,13 +211,13 @@
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.is_callback_with_sign = False
        self.sil_frame = 0
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
        self.speech_noise_thres = self.vad_opts.speech_noise_thres
@@ -226,10 +225,13 @@
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.streaming = streaming
        self.ResetDetection()
    def AllResetDetection(self):
        self.encoder.cache_reset()  # reset the in_cache in self.encoder for next query or next long sentence
        self.is_final_send = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
@@ -240,13 +242,13 @@
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.is_callback_with_sign = False
        self.sil_frame = 0
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
        self.speech_noise_thres = self.vad_opts.speech_noise_thres
@@ -254,6 +256,7 @@
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()
@@ -271,26 +274,32 @@
    def ComputeDecibel(self) -> None:
        """Append one decibel value per frame of ``self.waveform`` to ``self.decibel``.

        Also refreshes ``self.data_buf`` and ``self.data_buf_all`` from
        ``self.waveform[0]``. The value stored for each frame is
        ``10 * log10(sum(x ** 2) + 1e-6)`` over a window of
        ``frame_sample_length`` samples, advanced by ``frame_shift_length``
        samples per step.
        """
        # Window length and hop expressed in samples
        # (option names suggest millisecond / Hz inputs -- TODO confirm units).
        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        self.data_buf = self.waveform[0]  # points at self.waveform[0]
        if self.data_buf_all is None:
            # Nothing accumulated yet: the accumulated buffer starts as this chunk.
            self.data_buf_all = self.waveform[0]  # data_buf_all starts as self.waveform[0]
            self.data_buf = self.data_buf_all
        else:
            # Later chunks are appended to what has been accumulated so far.
            self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
        # Only offsets that leave room for a full window are visited, so a
        # trailing partial window produces no decibel entry.
        for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
            # The 1e-6 offset keeps log10 defined when the window sums to zero.
            self.decibel.append(
                10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                0.000001))
    def ComputeScores(self, feats: torch.Tensor, feats_lengths: int) -> None:
        """Run the encoder on ``feats`` and store the result and frame count.

        Args:
            feats: Input passed unchanged to ``self.encoder``.
            feats_lengths: Value stored as ``self.frm_cnt``.
        """
        # Overwrites any previous scores; nothing is accumulated across calls.
        self.scores = self.encoder(feats)  # return B * T * D
        self.frm_cnt = feats_lengths  # frame count supplied by the caller
        # return self.scores
    def ComputeScores(self, feats: torch.Tensor) -> None:
        """Score one chunk of features and append the result to ``self.scores``.

        Side effects: sets ``self.vad_opts.nn_eval_block_size`` to the number
        of frames in this chunk, adds that number to ``self.frm_cnt``, and
        either initialises ``self.scores`` or concatenates onto it along dim 1.

        Args:
            feats: Input passed unchanged to ``self.encoder``; its size along
                dim 1 must equal that of the encoder output.

        Raises:
            AssertionError: If encoder output and ``feats`` differ along dim 1.
        """
        scores = self.encoder(feats)  # return B * T * D
        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
        # The block size is overwritten on every call with this chunk's length.
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        self.frm_cnt += scores.shape[1]  # count total frames
        if self.scores is None:
            self.scores = scores  # the first calculation
        else:
            # Subsequent chunks are appended along dim 1.
            self.scores = torch.cat((self.scores, scores), dim=1)
    # Advance self.data_buf_start_frame up to frame_idx, re-slicing
    # self.data_buf by one frame shift
    # (int(frame_in_ms * sample_rate / 1000) samples) per step.
    #
    # NOTE(review): the two consecutive `self.data_buf = ...` lines below look
    # like the before/after forms of a single statement from a diff -- confirm
    # which one is live.
    # NOTE(review): when self.data_buf is shorter than one frame shift while
    # data_buf_start_frame < frame_idx, nothing in the loop body changes, so
    # the loop condition stays true -- confirm callers cannot reach that state.
    def PopDataBufTillFrame(self, frame_idx: int) -> None:  # need check again
        while self.data_buf_start_frame < frame_idx:
            if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                self.data_buf_start_frame += 1
                self.data_buf = self.waveform[0][self.data_buf_start_frame * int(
                self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
                # Earlier deque-style variant, kept for reference:
                # for i in range(0, int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)):
                #     self.data_buf.popleft()
                # self.data_buf_start_frame += 1
    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                           last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
@@ -301,8 +310,9 @@
                               self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            # expected_sample_number = max(expected_sample_number, len(self.data_buf))
            pass
            expected_sample_number = max(expected_sample_number, len(self.data_buf))
        if len(self.data_buf) < expected_sample_number:
            print('error in calling pop data_buf\n')
        if len(self.output_data_buf) == 0 or first_frm_is_start_point:
            self.output_data_buf.append(E2EVadSpeechBufWithDoa())
@@ -312,15 +322,18 @@
            self.output_data_buf[-1].doa = 0
        cur_seg = self.output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('warning')
            print('warning\n')
        out_pos = len(cur_seg.buffer)  # cur_seg.buff现在没做任何操作
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
        else:
            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        # if data_to_pop > len(self.data_buf_)
        #   pass
        if data_to_pop > len(self.data_buf):
            print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
            data_to_pop = len(self.data_buf)
            expected_sample_number = len(self.data_buf)
        cur_seg.doa = 0
        for sample_cpy_out in range(0, data_to_pop):
            # cur_seg.buffer[out_pos ++] = data_buf_.back();
@@ -329,7 +342,7 @@
            # cur_seg.buffer[out_pos++] = data_buf_.back()
            out_pos += 1
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('warning')
            print('Something wrong with the VAD algorithm\n')
        self.data_buf_start_frame += frm_cnt
        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
        if first_frm_is_start_point:
@@ -346,14 +359,13 @@
    def OnVoiceDetected(self, valid_frame: int) -> None:
        """Record ``valid_frame`` as the latest confirmed speech frame and pop it to the output buffer.

        Args:
            valid_frame: Frame index stored in
                ``self.latest_confirmed_speech_frame`` and passed to
                ``PopDataToOutputBuf`` with a frame count of 1 and all three
                boolean flags False.
        """
        self.latest_confirmed_speech_frame = valid_frame
        # NOTE(review): as shown here PopDataToOutputBuf is called twice, once
        # under the always-true guard and once unconditionally. This looks like
        # the before/after lines of a diff -- confirm only one call is intended.
        if True:  # is_new_api_enable_ = True
            self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
    def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
        if self.vad_opts.do_start_point_detection:
            pass
        if self.confirmed_start_frame != -1:
            print('warning')
            print('not reset vad properly\n')
        else:
            self.confirmed_start_frame = start_frame
@@ -366,7 +378,7 @@
        if self.vad_opts.do_end_point_detection:
            pass
        if self.confirmed_end_frame != -1:
            print('warning')
            print('not reset vad properly\n')
        else:
            self.confirmed_end_frame = end_frame
        if not fake_result:
@@ -406,7 +418,6 @@
            sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
            sum_score = sum(sil_pdf_scores)
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            # total_score = sum(self.scores[0][t][:])
            total_score = 1.0
            sum_score = total_score - sum_score
        speech_prob = math.log(sum_score)
@@ -433,25 +444,59 @@
        return frame_state
    # NOTE(review): two `forward` signatures follow back to back, and the body
    # mixes calls that match each of them. This looks like the before/after
    # lines of a diff -- confirm which lines are live before relying on the
    # comments below.
    def forward(self, feats: torch.Tensor, feats_lengths: int, waveform: torch.tensor) -> List[List[List[int]]]:
        self.AllResetDetection()
    def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]:
        """Score one chunk, run frame detection, and return ``[start_ms, end_ms]`` segments.

        Args:
            feats: Features passed to ``ComputeScores``; ``feats.shape[0]`` is
                used as the batch loop bound.
            waveform: Stored as ``self.waveform`` and consumed by
                ``ComputeDecibel``.
            is_final_send: When False, ``DetectCommonFrames`` is run. When True,
                ``DetectLastFrames`` is run if ``self.streaming`` is set,
                otherwise ``AllResetDetection`` followed by ``DetectAllFrames``.

        Returns:
            A list with one entry per batch element, each a list of
            ``[start_ms, end_ms]`` pairs read from ``self.output_data_buf``.
        """
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
        # NOTE(review): `feats_lengths` is not a parameter of this signature.
        self.ComputeScores(feats, feats_lengths)
        assert len(self.decibel) == len(self.scores[0])  # ensure the frame counts match
        self.DetectLastFrames()
        self.ComputeScores(feats)
        if not is_final_send:
            self.DetectCommonFrames()
        else:
            if self.streaming:
                self.DetectLastFrames()
            else:
                # NOTE(review): AllResetDetection sets self.frm_cnt = 0, and
                # DetectAllFrames iterates over range(0, self.frm_cnt) --
                # confirm the reset does not discard the frames just scored.
                self.AllResetDetection()
                self.DetectAllFrames()  # offline decode and is_final_send == True
        segments = []
        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            # Collects every buffered segment regardless of completion flags.
            for i in range(0, len(self.output_data_buf)):
                segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
                segment_batch.append(segment)
            segments.append(segment_batch)
            if len(self.output_data_buf) > 0:
                # Starting at output_data_buf_offset, collect only segments that
                # carry both a start point and an end point.
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
                    if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[
                        i].contain_seg_end_point:
                        segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
                        segment_batch.append(segment)
                        self.output_data_buf_offset += 1  # need update this parameter
            # An empty batch result is not appended by this branch.
            if segment_batch:
                segments.append(segment_batch)
        return segments
    def DetectCommonFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
        return 0
    def DetectLastFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            if i != 0:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
            else:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
        return 0
    def DetectAllFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size:
            frame_state = FrameState.kFrameStateInvalid
            for t in range(0, self.frm_cnt):