import torch
import torch.nn as nn
from typing import List, Optional, Sequence, Tuple, Union

from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
from funasr.models.transformer.embedding import (
    SinusoidalPositionEncoder,
    StreamSinusoidalPositionEncoder,
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d

from funasr.models.ctc.ctc import CTC

from funasr.register import tables


class EncoderLayerSANM(nn.Module):
    def __init__(

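        # In the layer's forward pass: pre-norm LayerNorm is applied before the SANM
        # self-attention when normalize_before is set; the chunk masks passed below
        # restrict attention for streaming.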
        x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        chunk_size: Union[int, Sequence[int]] = (16,),
        stride: Union[int, Sequence[int]] = (10,),
        pad_left: Union[int, Sequence[int]] = (0,),
        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        super().__init__()
        self._output_size = output_size

        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:

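        # Overlap-chunk setup for streaming: split the padded features into chunks
        # and build the shift / encoder-attention masks that limit each frame's
        # attention to its own chunk (plus any configured look-back).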
        chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
        xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
            chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
        )
        mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
            chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
        )

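        # Run the leading block(s) in self.encoders0, then the remaining blocks in
        # self.encoders; when intermediate CTC layers are configured, per-layer
        # outputs are collected along the way.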
        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(
                xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
            )
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(
                    xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
                )
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

            return feats
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :]
        return overlap_feats

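    # Streaming entry point: processes one chunk of features per call, carrying
    # state between calls in `cache` (keys used below include "opt", "chunk_size"
    # and "encoder_chunk_look_back").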
    def forward_chunk(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ):
        is_final = kwargs.get("is_final", False)
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:

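        # Per-layer attention caches live under cache["opt"]; each SANM layer reads
        # its slot and returns an updated one, written back at the end of the method.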
        new_cache = cache["opt"]

        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer.forward_chunk(
                xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"]
            )
            xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]

        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer.forward_chunk(
                xs_pad,
                new_cache[layer_idx + len(self.encoders0)],
                cache["chunk_size"],
                cache["encoder_chunk_look_back"],
            )
            xs_pad, new_cache[layer_idx + len(self.encoders0)] = encoder_outs[0], encoder_outs[1]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        cache["opt"] = new_cache

        return xs_pad, ilens, None
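

# Hedged usage sketch (illustrative, not part of the original file): the offline
# entry point is forward(xs_pad, ilens), returning the encoded sequence, its
# lengths, and an optional third tensor. The instance name and feature sizes
# below are assumptions.
#
#   encoder = ...  # an instance of the encoder class defined above
#   xs_pad = torch.randn(2, 100, input_size)   # (batch, frames, feature_dim)
#   ilens = torch.tensor([100, 80])            # valid frame counts per utterance
#   enc_out, enc_out_lens, _ = encoder(xs_pad, ilens)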