| | |
| | | from funasr.train_utils.device_funcs import to_device |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM |
| | | from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder |
| | | from funasr.models.transformer.embedding import ( |
| | | SinusoidalPositionEncoder, |
| | | StreamSinusoidalPositionEncoder, |
| | | ) |
| | | from funasr.models.transformer.layer_norm import LayerNorm |
| | | from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear |
| | | from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d |
| | |
| | | from funasr.models.ctc.ctc import CTC |
| | | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | class EncoderLayerSANM(nn.Module): |
| | | def __init__( |
| | |
| | | self.stochastic_depth_rate = stochastic_depth_rate |
| | | self.dropout_rate = dropout_rate |
| | | |
| | | def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None): |
| | | def forward(self, x, mask, cache=None, mask_shift_chunk=None, mask_att_chunk_encoder=None): |
| | | """Compute encoded features. |
| | | |
| | | Args: |
| | |
| | | x = self.norm1(x) |
| | | |
| | | if self.concat_after: |
| | | x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1) |
| | | x_concat = torch.cat( |
| | | ( |
| | | x, |
| | | self.self_attn( |
| | | x, |
| | | mask, |
| | | mask_shift_chunk=mask_shift_chunk, |
| | | mask_att_chunk_encoder=mask_att_chunk_encoder, |
| | | ), |
| | | ), |
| | | dim=-1, |
| | | ) |
| | | if self.in_size == self.size: |
| | | x = residual + stoch_layer_coeff * self.concat_linear(x_concat) |
| | | else: |
| | |
| | | else: |
| | | if self.in_size == self.size: |
| | | x = residual + stoch_layer_coeff * self.dropout( |
| | | self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) |
| | | self.self_attn( |
| | | x, |
| | | mask, |
| | | mask_shift_chunk=mask_shift_chunk, |
| | | mask_att_chunk_encoder=mask_att_chunk_encoder, |
| | | ) |
| | | ) |
| | | else: |
| | | x = stoch_layer_coeff * self.dropout( |
| | | self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) |
| | | self.self_attn( |
| | | x, |
| | | mask, |
| | | mask_shift_chunk=mask_shift_chunk, |
| | | mask_att_chunk_encoder=mask_att_chunk_encoder, |
| | | ) |
| | | ) |
| | | if not self.normalize_before: |
| | | x = self.norm1(x) |
| | |
| | | if not self.normalize_before: |
| | | x = self.norm2(x) |
| | | |
| | | return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder |
| | | return x, mask, cache, mask_shift_chunk, mask_att_chunk_encoder |
| | | |
| | | def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): |
| | | """Compute encoded features. |
| | |
| | | |
| | | return x, cache |
| | | |
| | | |
| | | @tables.register("encoder_classes", "SANMEncoder") |
| | | class SANMEncoder(nn.Module): |
| | | """ |
| | |
| | | padding_idx: int = -1, |
| | | interctc_layer_idx: List[int] = [], |
| | | interctc_use_conditioning: bool = False, |
| | | kernel_size : int = 11, |
| | | sanm_shfit : int = 0, |
| | | kernel_size: int = 11, |
| | | sanm_shift: int = 0, |
| | | lora_list: List[str] = None, |
| | | lora_rank: int = 8, |
| | | lora_alpha: int = 16, |
| | |
| | | output_size, |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | sanm_shift, |
| | | lora_list, |
| | | lora_rank, |
| | | lora_alpha, |
| | |
| | | output_size, |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | sanm_shift, |
| | | lora_list, |
| | | lora_rank, |
| | | lora_alpha, |
| | |
| | | ) |
| | | |
| | | self.encoders = repeat( |
| | | num_blocks-1, |
| | | num_blocks - 1, |
| | | lambda lnum: EncoderLayerSANM( |
| | | output_size, |
| | | output_size, |
| | |
| | | position embedded tensor and mask |
| | | """ |
| | | masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) |
| | | xs_pad = xs_pad * self.output_size()**0.5 |
| | | xs_pad = xs_pad * self.output_size() ** 0.5 |
| | | if self.embed is None: |
| | | xs_pad = xs_pad |
| | | elif ( |
| | |
| | | return feats |
| | | cache["feats"] = to_device(cache["feats"], device=feats.device) |
| | | overlap_feats = torch.cat((cache["feats"], feats), dim=1) |
| | | cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] |
| | | cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :] |
| | | return overlap_feats |
| | | |
| | | def forward_chunk(self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | cache: dict = None, |
| | | ctc: CTC = None, |
| | | ): |
| | | def forward_chunk( |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | cache: dict = None, |
| | | ctc: CTC = None, |
| | | ): |
| | | xs_pad *= self.output_size() ** 0.5 |
| | | if self.embed is None: |
| | | xs_pad = xs_pad |
| | |
| | | return (xs_pad, intermediate_outs), None, None |
| | | return xs_pad, ilens, None |
| | | |
| | | |
| | | class EncoderLayerSANMExport(nn.Module): |
| | | def __init__( |
| | | self, |
| | |
| | | |
| | | return x, mask |
| | | |
| | | |
| | | @tables.register("encoder_classes", "SANMEncoderChunkOptExport") |
| | | @tables.register("encoder_classes", "SANMEncoderExport") |
| | | class SANMEncoderExport(nn.Module): |
| | |
| | | model, |
| | | max_seq_len=512, |
| | | feats_dim=560, |
| | | model_name='encoder', |
| | | model_name="encoder", |
| | | onnx: bool = True, |
| | | ctc_linear: nn.Module = None, |
| | | ): |
| | | super().__init__() |
| | | self.embed = model.embed |
| | |
| | | self.feats_dim = feats_dim |
| | | self._output_size = model._output_size |
| | | |
| | | |
| | | from funasr.utils.torch_function import sequence_mask |
| | | |
| | | |
| | | self.make_pad_mask = sequence_mask(max_seq_len, flip=False) |
| | | |
| | | |
| | | from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport |
| | | if hasattr(model, 'encoders0'): |
| | | |
| | | if hasattr(model, "encoders0"): |
| | | for i, d in enumerate(self.model.encoders0): |
| | | if isinstance(d.self_attn, MultiHeadedAttentionSANM): |
| | | d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn) |
| | | self.model.encoders0[i] = EncoderLayerSANMExport(d) |
| | | |
| | | |
| | | for i, d in enumerate(self.model.encoders): |
| | | if isinstance(d.self_attn, MultiHeadedAttentionSANM): |
| | | d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn) |
| | | self.model.encoders[i] = EncoderLayerSANMExport(d) |
| | | |
| | | |
| | | self.model_name = model_name |
| | | self.num_heads = model.encoders[0].self_attn.h |
| | | self.hidden_size = model.encoders[0].self_attn.linear_out.out_features |
| | | |
| | | |
| | | self.ctc_linear = ctc_linear |
| | | |
| | | def prepare_mask(self, mask): |
| | | mask_3d_btd = mask[:, :, None] |
| | | if len(mask.shape) == 2: |
| | |
| | | elif len(mask.shape) == 3: |
| | | mask_4d_bhlt = 1 - mask[:, None, :] |
| | | mask_4d_bhlt = mask_4d_bhlt * -10000.0 |
| | | |
| | | |
| | | return mask_3d_btd, mask_4d_bhlt |
| | | |
| | | def forward(self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | online: bool = False |
| | | ): |
| | | |
| | | def forward(self, speech: torch.Tensor, speech_lengths: torch.Tensor, online: bool = False): |
| | | if not online: |
| | | speech = speech * self._output_size ** 0.5 |
| | | speech = speech * self._output_size**0.5 |
| | | |
| | | mask = self.make_pad_mask(speech_lengths) |
| | | mask = self.prepare_mask(mask) |
| | | if self.embed is None: |
| | | xs_pad = speech |
| | | else: |
| | | xs_pad = self.embed(speech) |
| | | |
| | | |
| | | encoder_outs = self.model.encoders0(xs_pad, mask) |
| | | xs_pad, masks = encoder_outs[0], encoder_outs[1] |
| | | |
| | | |
| | | encoder_outs = self.model.encoders(xs_pad, mask) |
| | | xs_pad, masks = encoder_outs[0], encoder_outs[1] |
| | | |
| | | |
| | | xs_pad = self.model.after_norm(xs_pad) |
| | | |
| | | |
| | | if self.ctc_linear is not None: |
| | | xs_pad = self.ctc_linear(xs_pad) |
| | | xs_pad = F.softmax(xs_pad, dim=2) |
| | | |
| | | return xs_pad, speech_lengths |
| | | |
| | | |
| | | def get_output_size(self): |
| | | return self.model.encoders[0].size |
| | | |
| | | |
| | | def get_dummy_inputs(self): |
| | | feats = torch.randn(1, 100, self.feats_dim) |
| | | return (feats) |
| | | |
| | | return feats |
| | | |
| | | def get_input_names(self): |
| | | return ['feats'] |
| | | |
| | | return ["feats"] |
| | | |
| | | def get_output_names(self): |
| | | return ['encoder_out', 'encoder_out_lens', 'predictor_weight'] |
| | | |
| | | return ["encoder_out", "encoder_out_lens", "predictor_weight"] |
| | | |
| | | def get_dynamic_axes(self): |
| | | return { |
| | | 'feats': { |
| | | 1: 'feats_length' |
| | | }, |
| | | 'encoder_out': { |
| | | 1: 'enc_out_length' |
| | | }, |
| | | 'predictor_weight': { |
| | | 1: 'pre_out_length' |
| | | } |
| | | |
| | | "feats": {1: "feats_length"}, |
| | | "encoder_out": {1: "enc_out_length"}, |
| | | "predictor_weight": {1: "pre_out_length"}, |
| | | } |
| | | |