Merge pull request #123 from alibaba-damo-academy/dev_zly
support vad streaming decoder
| | |
| | | import argparse |
| | | import logging |
| | | import sys |
| | | import json |
| | | from pathlib import Path |
| | | from typing import Any |
| | | from typing import List |
| | |
| | | feats_len = feats_len.int() |
| | | else: |
| | | raise Exception("Need to extract feats first, please configure frontend configuration") |
| | | batch = {"feats": feats, "feats_lengths": feats_len, "waveform": speech} |
| | | # batch = {"feats": feats, "waveform": speech, "is_final_send": True} |
| | | # segments = self.vad_model(**batch) |
| | | |
| | | # b. Forward encoder, streaming |
| | | segments = [] |
| | | step = 6000 |
| | | t_offset = 0 |
| | | for t_offset in range(0, feats_len, min(step, feats_len - t_offset)): |
| | | if t_offset + step >= feats_len - 1: |
| | | step = feats_len - t_offset |
| | | is_final_send = True |
| | | else: |
| | | is_final_send = False |
| | | batch = { |
| | | "feats": feats[:, t_offset:t_offset + step, :], |
| | | "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)], |
| | | "is_final_send": is_final_send |
| | | } |
| | | # a. To device |
| | | batch = to_device(batch, device=self.device) |
| | | |
| | | # b. Forward Encoder |
| | | segments = self.vad_model(**batch) |
| | | segments_part = self.vad_model(**batch) |
| | | if segments_part: |
| | | segments += segments_part |
| | | #print(segments) |
| | | |
| | | return segments |
| | | |
| | | |
| | | |
| | | |
| | | def inference( |
| | |
| | | ) |
| | | return inference_pipeline(data_path_and_name_and_type, raw_inputs) |
| | | |
| | | |
| | | def inference_modelscope( |
| | | batch_size: int, |
| | | ngpu: int, |
| | |
| | | dtype: str = "float32", |
| | | seed: int = 0, |
| | | num_workers: int = 1, |
| | | param_dict: dict = None, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | |
| | | # do vad segment |
| | | results = speech2vadsegment(**batch) |
| | | for i, _ in enumerate(keys): |
| | | results[i] = json.dumps(results[i]) |
| | | item = {'key': keys[i], 'value': results[i]} |
| | | vad_results.append(item) |
| | | if writer is not None: |
| | | results[i] = json.loads(results[i]) |
| | | ibest_writer["text"][keys[i]] = "{}".format(results[i]) |
| | | |
| | | return vad_results |
| | |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Dispatch a VAD inference request to the offline or online backend.

    Args:
        mode: "vad" for offline (whole-utterance) decoding, "vad_online" for
            streaming decoding. Any other value is logged and yields None.
        **kwargs: forwarded verbatim to the selected backend's
            inference_modelscope().

    Returns:
        The backend's inference pipeline, or None for an unknown mode.
    """
    # The merged diff left `if mode == "offline"` nested under
    # `if mode == "vad"`, which is unreachable (mode cannot equal both).
    # Dispatch on two distinct mode names instead.
    if mode == "vad":
        from funasr.bin.vad_inference import inference_modelscope
        return inference_modelscope(**kwargs)
    elif mode == "vad_online":
        from funasr.bin.vad_inference_online import inference_modelscope
        return inference_modelscope(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
| | | |
| | | |
| | | def main(cmd=None): |
| | | print(get_commandline_args(), file=sys.stderr) |
| | |
| | | from torch import nn |
| | | import math |
| | | from funasr.models.encoder.fsmn_encoder import FSMN |
| | | # from checkpoint import load_checkpoint |
| | | |
| | | |
| | | class VadStateMachine(Enum): |
| | |
| | | |
| | | self.win_size_frame = int(window_size_ms / frame_size_ms) |
| | | self.win_sum = 0 |
| | | self.win_state = [0 for i in range(0, self.win_size_frame)] # 初始化窗 |
| | | self.win_state = [0] * self.win_size_frame # 初始化窗 |
| | | |
| | | self.cur_win_pos = 0 |
| | | self.pre_frame_state = FrameState.kFrameStateSil |
| | |
| | | def Reset(self) -> None: |
| | | self.cur_win_pos = 0 |
| | | self.win_sum = 0 |
| | | self.win_state = [0 for i in range(0, self.win_size_frame)] |
| | | self.win_state = [0] * self.win_size_frame |
| | | self.pre_frame_state = FrameState.kFrameStateSil |
| | | self.cur_frame_state = FrameState.kFrameStateSil |
| | | self.voice_last_frame_count = 0 |
| | |
| | | return int(self.frame_size_ms) |
| | | |
| | | |
| | | class E2EVadModel(torch.nn.Module): |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]): |
| | | class E2EVadModel(nn.Module): |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False): |
| | | super(E2EVadModel, self).__init__() |
| | | self.vad_opts = VADXOptions(**vad_post_args) |
| | | self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, |
| | |
| | | self.confirmed_start_frame = -1 |
| | | self.confirmed_end_frame = -1 |
| | | self.number_end_time_detected = 0 |
| | | self.is_callback_with_sign = False |
| | | self.sil_frame = 0 |
| | | self.sil_pdf_ids = self.vad_opts.sil_pdf_ids |
| | | self.noise_average_decibel = -100.0 |
| | | self.pre_end_silence_detected = False |
| | | |
| | | self.output_data_buf = [] |
| | | self.output_data_buf_offset = 0 |
| | | self.frame_probs = [] |
| | | self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres |
| | | self.speech_noise_thres = self.vad_opts.speech_noise_thres |
| | |
| | | self.max_time_out = False |
| | | self.decibel = [] |
| | | self.data_buf = None |
| | | self.data_buf_all = None |
| | | self.waveform = None |
| | | self.streaming = streaming |
| | | self.ResetDetection() |
| | | |
| | | def AllResetDetection(self): |
| | | self.encoder.cache_reset() # reset the in_cache in self.encoder for next query or next long sentence |
| | | self.is_final_send = False |
| | | self.data_buf_start_frame = 0 |
| | | self.frm_cnt = 0 |
| | |
| | | self.confirmed_start_frame = -1 |
| | | self.confirmed_end_frame = -1 |
| | | self.number_end_time_detected = 0 |
| | | self.is_callback_with_sign = False |
| | | self.sil_frame = 0 |
| | | self.sil_pdf_ids = self.vad_opts.sil_pdf_ids |
| | | self.noise_average_decibel = -100.0 |
| | | self.pre_end_silence_detected = False |
| | | |
| | | self.output_data_buf = [] |
| | | self.output_data_buf_offset = 0 |
| | | self.frame_probs = [] |
| | | self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres |
| | | self.speech_noise_thres = self.vad_opts.speech_noise_thres |
| | |
| | | self.max_time_out = False |
| | | self.decibel = [] |
| | | self.data_buf = None |
| | | self.data_buf_all = None |
| | | self.waveform = None |
| | | self.ResetDetection() |
| | | |
| | |
| | | def ComputeDecibel(self) -> None: |
| | | frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000) |
| | | frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) |
| | | self.data_buf = self.waveform[0] # 指向self.waveform[0] |
| | | if self.data_buf_all is None: |
| | | self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0] |
| | | self.data_buf = self.data_buf_all |
| | | else: |
| | | self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0])) |
| | | for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length): |
| | | self.decibel.append( |
| | | 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \ |
| | | 0.000001)) |
| | | |
| | | def ComputeScores(self, feats: torch.Tensor, feats_lengths: int) -> None: |
| | | self.scores = self.encoder(feats) # return B * T * D |
| | | self.frm_cnt = feats_lengths # frame |
| | | # return self.scores |
| | | def ComputeScores(self, feats: torch.Tensor) -> None: |
| | | scores = self.encoder(feats) # return B * T * D |
| | | assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match" |
| | | self.vad_opts.nn_eval_block_size = scores.shape[1] |
| | | self.frm_cnt += scores.shape[1] # count total frames |
| | | if self.scores is None: |
| | | self.scores = scores # the first calculation |
| | | else: |
| | | self.scores = torch.cat((self.scores, scores), dim=1) |
| | | |
| | | def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again |
| | | while self.data_buf_start_frame < frame_idx: |
| | | if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000): |
| | | self.data_buf_start_frame += 1 |
| | | self.data_buf = self.waveform[0][self.data_buf_start_frame * int( |
| | | self.data_buf = self.data_buf_all[self.data_buf_start_frame * int( |
| | | self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] |
| | | # for i in range(0, int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)): |
| | | # self.data_buf.popleft() |
| | | # self.data_buf_start_frame += 1 |
| | | |
| | | def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool, |
| | | last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None: |
| | |
| | | self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)) |
| | | expected_sample_number += int(extra_sample) |
| | | if end_point_is_sent_end: |
| | | # expected_sample_number = max(expected_sample_number, len(self.data_buf)) |
| | | pass |
| | | expected_sample_number = max(expected_sample_number, len(self.data_buf)) |
| | | if len(self.data_buf) < expected_sample_number: |
| | | print('error in calling pop data_buf\n') |
| | | |
| | | if len(self.output_data_buf) == 0 or first_frm_is_start_point: |
| | | self.output_data_buf.append(E2EVadSpeechBufWithDoa()) |
| | |
| | | self.output_data_buf[-1].doa = 0 |
| | | cur_seg = self.output_data_buf[-1] |
| | | if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: |
| | | print('warning') |
| | | print('warning\n') |
| | | out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作 |
| | | data_to_pop = 0 |
| | | if end_point_is_sent_end: |
| | | data_to_pop = expected_sample_number |
| | | else: |
| | | data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) |
| | | # if data_to_pop > len(self.data_buf_) |
| | | # pass |
| | | if data_to_pop > len(self.data_buf): |
| | | print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n') |
| | | data_to_pop = len(self.data_buf) |
| | | expected_sample_number = len(self.data_buf) |
| | | |
| | | cur_seg.doa = 0 |
| | | for sample_cpy_out in range(0, data_to_pop): |
| | | # cur_seg.buffer[out_pos ++] = data_buf_.back(); |
| | |
| | | # cur_seg.buffer[out_pos++] = data_buf_.back() |
| | | out_pos += 1 |
| | | if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: |
| | | print('warning') |
| | | print('Something wrong with the VAD algorithm\n') |
| | | self.data_buf_start_frame += frm_cnt |
| | | cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms |
| | | if first_frm_is_start_point: |
| | |
| | | |
| | | def OnVoiceDetected(self, valid_frame: int) -> None: |
| | | self.latest_confirmed_speech_frame = valid_frame |
| | | if True: # is_new_api_enable_ = True |
| | | self.PopDataToOutputBuf(valid_frame, 1, False, False, False) |
| | | |
| | | def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None: |
| | | if self.vad_opts.do_start_point_detection: |
| | | pass |
| | | if self.confirmed_start_frame != -1: |
| | | print('warning') |
| | | print('not reset vad properly\n') |
| | | else: |
| | | self.confirmed_start_frame = start_frame |
| | | |
| | |
| | | if self.vad_opts.do_end_point_detection: |
| | | pass |
| | | if self.confirmed_end_frame != -1: |
| | | print('warning') |
| | | print('not reset vad properly\n') |
| | | else: |
| | | self.confirmed_end_frame = end_frame |
| | | if not fake_result: |
| | |
| | | sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids] |
| | | sum_score = sum(sil_pdf_scores) |
| | | noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio |
| | | # total_score = sum(self.scores[0][t][:]) |
| | | total_score = 1.0 |
| | | sum_score = total_score - sum_score |
| | | speech_prob = math.log(sum_score) |
| | |
| | | |
| | | return frame_state |
| | | |
| | | def forward(self, feats: torch.Tensor, feats_lengths: int, waveform: torch.tensor) -> List[List[List[int]]]: |
| | | self.AllResetDetection() |
| | | def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]: |
| | | self.waveform = waveform # compute decibel for each frame |
| | | self.ComputeDecibel() |
| | | self.ComputeScores(feats, feats_lengths) |
| | | assert len(self.decibel) == len(self.scores[0]) # 保证帧数一致 |
| | | self.ComputeScores(feats) |
| | | if not is_final_send: |
| | | self.DetectCommonFrames() |
| | | else: |
| | | if self.streaming: |
| | | self.DetectLastFrames() |
| | | else: |
| | | self.AllResetDetection() |
| | | self.DetectAllFrames() # offline decode and is_final_send == True |
| | | segments = [] |
| | | for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now |
| | | segment_batch = [] |
| | | for i in range(0, len(self.output_data_buf)): |
| | | if len(self.output_data_buf) > 0: |
| | | for i in range(self.output_data_buf_offset, len(self.output_data_buf)): |
| | | if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[ |
| | | i].contain_seg_end_point: |
| | | segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms] |
| | | segment_batch.append(segment) |
| | | self.output_data_buf_offset += 1 # need update this parameter |
| | | if segment_batch: |
| | | segments.append(segment_batch) |
| | | |
| | | return segments |
| | | |
| | | def DetectCommonFrames(self) -> int: |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| | | return 0 |
| | | for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| | | frame_state = FrameState.kFrameStateInvalid |
| | | frame_state = self.GetFrameState(self.frm_cnt - 1 - i) |
| | | self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False) |
| | | |
| | | return 0 |
| | | |
| | | def DetectLastFrames(self) -> int: |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| | | return 0 |
| | | for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| | | frame_state = FrameState.kFrameStateInvalid |
| | | frame_state = self.GetFrameState(self.frm_cnt - 1 - i) |
| | | if i != 0: |
| | | self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False) |
| | | else: |
| | | self.DetectOneFrame(frame_state, self.frm_cnt - 1, True) |
| | | |
| | | return 0 |
| | | |
| | | def DetectAllFrames(self) -> int: |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| | | return 0 |
| | | if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size: |
| | | frame_state = FrameState.kFrameStateInvalid |
| | | for t in range(0, self.frm_cnt): |
| | |
| | | from typing import Tuple, Dict |
| | | import copy |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | |
| | | from typing import Tuple |
| | | |
| | | |
class LinearTransform(nn.Module):
    """Bias-free linear projection (the FSMN bottleneck 'shrink' layer).

    The merge removed quantization support, but the diff left the old
    QuantStub/DeQuantStub path in place as dead code: the trailing
    `output = self.linear(input)` overwrote everything the quant branch
    computed. The dead path and the superseded `quantize` parameter are gone.
    """

    def __init__(self, input_dim, output_dim):
        super(LinearTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, input):
        output = self.linear(input)
        return output
| | | |
| | | |
class AffineTransform(nn.Module):
    """Linear projection with bias (the FSMN 'expand' layer).

    As with LinearTransform, the quantization stubs left behind by the merge
    were dead code (the final `output = self.linear(input)` overwrote the
    quantized result) and have been removed along with the superseded
    `quantize` parameter.
    """

    def __init__(self, input_dim, output_dim):
        super(AffineTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, input):
        output = self.linear(input)
        return output
| | | |
| | | |
class RectifiedLinear(nn.Module):
    """Pointwise ReLU activation layer.

    `output_dim` is accepted only for interface symmetry with the other
    transform layers. NOTE(review): this class is defined twice in this file
    (a later duplicate exists); at import time the later definition wins.
    """

    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()
        # Declared for parity with the original module layout; forward() does
        # not apply it.
        self.dropout = nn.Dropout(0.1)

    def forward(self, input):
        return self.relu(input)
| | | |
| | | |
| | | class FSMNBlock(nn.Module): |
| | |
| | | rorder=None, |
| | | lstride=1, |
| | | rstride=1, |
| | | quantize=0 |
| | | ): |
| | | super(FSMNBlock, self).__init__() |
| | | |
| | |
| | | self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False) |
| | | else: |
| | | self.conv_right = None |
| | | self.quantize = quantize |
| | | self.quant = torch.quantization.QuantStub() |
| | | self.dequant = torch.quantization.DeQuantStub() |
| | | |
| | | def forward(self, input): |
| | | def forward(self, input: torch.Tensor, in_cache=None): |
| | | x = torch.unsqueeze(input, 1) |
| | | x_per = x.permute(0, 3, 2, 1) |
| | | |
| | | x_per = x.permute(0, 3, 2, 1) # B D T C |
| | | if in_cache is None: # offline |
| | | y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) |
| | | if self.quantize: |
| | | y_left = self.quant(y_left) |
| | | else: |
| | | y_left = torch.cat((in_cache, x_per), dim=2) |
| | | in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] |
| | | y_left = self.conv_left(y_left) |
| | | if self.quantize: |
| | | y_left = self.dequant(y_left) |
| | | out = x_per + y_left |
| | | |
| | | if self.conv_right is not None: |
| | | # maybe need to check |
| | | y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride]) |
| | | y_right = y_right[:, :, self.rstride:, :] |
| | | if self.quantize: |
| | | y_right = self.quant(y_right) |
| | | y_right = self.conv_right(y_right) |
| | | if self.quantize: |
| | | y_right = self.dequant(y_right) |
| | | out += y_right |
| | | |
| | | out_per = out.permute(0, 3, 2, 1) |
| | | output = out_per.squeeze(1) |
| | | |
| | | return output |
| | | return output, in_cache |
| | | |
| | | |
class RectifiedLinear(nn.Module):
    """ReLU activation (dropout declared but deliberately left unused).

    NOTE(review): this duplicates an earlier RectifiedLinear definition in the
    same file; Python keeps this later one. Both behave identically.
    """

    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, input):
        # Dropout intentionally skipped (the original kept it commented out).
        return self.relu(input)
| | | |
| | | |
class BasicBlock(nn.Sequential):
    """One FSMN layer: shrink -> memory block -> expand -> ReLU.

    The merge interleaved the superseded `_build_repeats(...)` factory with
    this class, leaving an unparseable span; only the post-merge BasicBlock
    survives here. `stack_layer` is this layer's index in the stack and keys
    its entry in the shared streaming-cache dict.
    """

    def __init__(self,
                 linear_dim: int,
                 proj_dim: int,
                 lorder: int,
                 rorder: int,
                 lstride: int,
                 rstride: int,
                 stack_layer: int
                 ):
        super(BasicBlock, self).__init__()
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        self.stack_layer = stack_layer
        self.linear = LinearTransform(linear_dim, proj_dim)
        self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
        self.affine = AffineTransform(proj_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)

    def forward(self, input: torch.Tensor, in_cache=None):
        """Apply the layer; a dict `in_cache` enables streaming left context."""
        x1 = self.linear(input)  # B T D
        if in_cache is not None:  # streaming: Dict[str, torch.Tensor]
            cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
            if cache_layer_name not in in_cache:
                # First chunk for this layer: start from an all-zero cache.
                in_cache[cache_layer_name] = torch.zeros(
                    x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
            x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
        else:
            x2, _ = self.fsmn_block(x1)
        x3 = self.affine(x2)
        x4 = self.relu(x3)
        return x4, in_cache
| | | |
| | | |
class FsmnStack(nn.Sequential):
    """Sequential container of FSMN layers that threads a shared cache dict.

    Each child must accept ``(x, in_cache)`` and return ``(x, in_cache)``;
    the cache travels through every layer while only the activations are
    returned to the caller.
    """

    def __init__(self, *args):
        super(FsmnStack, self).__init__(*args)

    def forward(self, input: torch.Tensor, in_cache=None):
        hidden = input
        for layer in self:
            hidden, in_cache = layer(hidden, in_cache)
        return hidden
| | | |
| | | |
| | | ''' |
| | |
| | | rstride: int, |
| | | output_affine_dim: int, |
| | | output_dim: int, |
| | | streaming=False |
| | | ): |
| | | super(FSMN, self).__init__() |
| | | |
| | |
| | | self.fsmn_layers = fsmn_layers |
| | | self.linear_dim = linear_dim |
| | | self.proj_dim = proj_dim |
| | | self.lorder = lorder |
| | | self.rorder = rorder |
| | | self.lstride = lstride |
| | | self.rstride = rstride |
| | | self.output_affine_dim = output_affine_dim |
| | | self.output_dim = output_dim |
| | | self.in_cache_original = dict() if streaming else None |
| | | self.in_cache = copy.deepcopy(self.in_cache_original) |
| | | |
| | | self.in_linear1 = AffineTransform(input_dim, input_affine_dim) |
| | | self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) |
| | | self.relu = RectifiedLinear(linear_dim, linear_dim) |
| | | |
| | | self.fsmn = _build_repeats(fsmn_layers, |
| | | linear_dim, |
| | | proj_dim, |
| | | lorder, rorder, |
| | | lstride, rstride) |
| | | |
| | | self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in |
| | | range(fsmn_layers)]) |
| | | self.out_linear1 = AffineTransform(linear_dim, output_affine_dim) |
| | | self.out_linear2 = AffineTransform(output_affine_dim, output_dim) |
| | | self.softmax = nn.Softmax(dim=-1) |
| | |
| | | def fuse_modules(self): |
| | | pass |
| | | |
| | | def cache_reset(self): |
| | | self.in_cache = copy.deepcopy(self.in_cache_original) |
| | | |
| | | def forward( |
| | | self, |
| | | input: torch.Tensor, |
| | | in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) |
| | | ) -> torch.Tensor: |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
| | | """ |
| | | Args: |
| | | input (torch.Tensor): Input tensor (B, T, D) |
| | | in_cache(torhc.Tensor): (B, D, C), C is the accumulated cache size |
| | | in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs, |
| | | {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame |
| | | """ |
| | | |
| | | x1 = self.in_linear1(input) |
| | | x2 = self.in_linear2(x1) |
| | | x3 = self.relu(x2) |
| | | x4 = self.fsmn(x3) |
| | | x4 = self.fsmn(x3, self.in_cache) # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn |
| | | x5 = self.out_linear1(x4) |
| | | x6 = self.out_linear2(x5) |
| | | x7 = self.softmax(x6) |
| | | |
| | | return x7 |
| | | # return x6, in_cache |
| | | |
| | | |
| | | ''' |
| | |
| | | model_class = model_choices.get_class(args.model) |
| | | except AttributeError: |
| | | model_class = model_choices.get_class("e2evad") |
| | | model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf) |
| | | model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, |
| | | streaming=args.encoder_conf.get('streaming', False)) |
| | | |
| | | return model |
| | | |