| | |
| | | |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.modules.attention import MultiHeadedAttention |
| | | from funasr.modules.attention import CosineDistanceAttention |
| | | from funasr.modules.dynamic_conv import DynamicConvolution |
| | | from funasr.modules.dynamic_conv2d import DynamicConvolution2D |
| | | from funasr.modules.embedding import PositionalEncoding |
| | |
| | | normalize_before, |
| | | concat_after, |
| | | ), |
| | | ) |
| | | ) |
| | | |
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base decoder with a speaker branch and an ASR branch.

    The constructor builds the token embedding, the optional final layer
    norm, the two optional output projections and the cosine-distance
    attention.  The four decoder stages (``decoder1`` .. ``decoder4``) are
    initialised to ``None`` here and must be assigned by a sub-class before
    ``forward`` / ``forward_one_step`` are called.

    Args:
        vocab_size: Size of the token vocabulary (embedding input size and
            ASR output size).
        encoder_output_size: Model dimension; used as the attention dimension.
        spker_embedding_dim: Output size of the speaker projection layer.
        dropout_rate: Dropout used by the ``"linear"`` input layer.
        positional_dropout_rate: Dropout passed to ``pos_enc_class``.
        input_layer: ``"embed"`` or ``"linear"``.
        use_asr_output_layer: Whether to create the ASR output projection.
        use_spk_output_layer: Whether to create the speaker output projection.
        pos_enc_class: Positional-encoding class, called as
            ``pos_enc_class(attention_dim, positional_dropout_rate)``.
        normalize_before: If true, ``after_norm`` is created and applied to
            the output of each branch.

    Raises:
        ValueError: If ``input_layer`` is neither ``"embed"`` nor ``"linear"``.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            # A single LayerNorm instance is shared by the speaker branch and
            # the ASR branch (see forward / forward_one_step).
            self.after_norm = LayerNorm(attention_dim)

        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None

        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None

        self.cos_distance_att = CosineDistanceAttention()

        # Placeholders; concrete sub-classes assign the actual decoder stages.
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None

    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run both decoder branches over a full (teacher-forced) target.

        Args:
            asr_hs_pad: Memory used by the ASR branch (and as attention key
                input of ``decoder1``).
            spk_hs_pad: Memory used by the speaker branch.
            hlens: Lengths used to build the memory padding mask.
            ys_in_pad: Decoder input fed to ``self.embed``.
            ys_in_lens: Lengths used to build the target padding mask.
            profile: Speaker profile passed to the cosine-distance attention.
            profile_lens: Profile lengths passed to the cosine-distance
                attention.

        Returns:
            Tuple ``(x, weights, olens)``: the ASR branch output (projected by
            ``asr_output_layer`` when it exists), the second value returned by
            ``cos_distance_att``, and ``tgt_mask.sum(1)``.
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)

        # Speaker decoder branch.  decoder1 additionally returns ``z``, which
        # is later used as the input of the ASR branch.
        x = self.embed(tgt)
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, profile_lens)

        # ASR decoder branch, conditioned on ``dn`` from the speaker branch.
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)

        olens = tgt_mask.sum(1)
        return x, weights, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        asr_memory: torch.Tensor,
        spk_memory: torch.Tensor,
        profile: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Run one incremental decoding step.

        The cache is a flat list laid out as
        ``[decoder1, *decoder2 layers, decoder3, *decoder4 layers]``.

        Args:
            tgt: Decoder input fed to ``self.embed``.
            tgt_mask: Mask for ``tgt``.
            asr_memory: Memory for the ASR branch.
            spk_memory: Memory for the speaker branch.
            profile: Speaker profile passed to the cosine-distance attention.
            cache: Per-layer outputs from the previous step, or ``None`` on
                the first step.

        Returns:
            Tuple ``(y, weights, new_cache)``: the last-position output (log
            softmax of the ASR projection when ``asr_output_layer`` exists),
            the second value returned by ``cos_distance_att``, and the
            per-layer outputs to pass as ``cache`` on the next step.
        """
        x = self.embed(tgt)

        num_spk_layers = len(self.decoder2)
        if cache is None:
            cache = [None] * (2 + num_spk_layers + len(self.decoder4))
        new_cache = []

        # Speaker branch.
        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[1:num_spk_layers + 1], self.decoder2):
            x, tgt_mask, spk_memory, _ = decoder(
                x, tgt_mask, spk_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, None)

        # ASR branch.
        x, tgt_mask, asr_memory, _ = self.decoder3(
            z, tgt_mask, asr_memory, None, dn, cache=cache[num_spk_layers + 1]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[num_spk_layers + 2:], self.decoder4):
            x, tgt_mask, asr_memory, _ = decoder(
                x, tgt_mask, asr_memory, None, cache=c
            )
            new_cache.append(x)

        # Only the last position is scored.
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.asr_output_layer is not None:
            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)

        return y, weights, new_cache

    def score(self, ys, state, asr_enc, spk_enc, profile):
        """Score a single hypothesis by adding a batch axis and decoding one step.

        Returns:
            Tuple ``(logp, weights, state)`` with the batch axis removed from
            ``logp``, all singleton axes removed from ``weights``, and the
            updated decoder cache.
        """
        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
        logp, weights, state = self.forward_one_step(
            ys.unsqueeze(0),
            ys_mask,
            asr_enc.unsqueeze(0),
            spk_enc.unsqueeze(0),
            profile.unsqueeze(0),
            cache=state,
        )
        return logp.squeeze(0), weights.squeeze(), state
| | | |
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    """Concrete decoder that fills in the four stages of the base class.

    ``decoder1`` and ``decoder3`` are the dedicated first layers of the
    speaker and ASR branches; ``decoder2`` and ``decoder4`` are stacks of
    ``spk_num_blocks - 1`` and ``asr_num_blocks - 1`` ``DecoderLayer``
    instances respectively.

    Args:
        vocab_size: Forwarded to the base class.
        encoder_output_size: Forwarded to the base class; also used as the
            attention dimension of every layer built here.
        spker_embedding_dim: Forwarded to the base class and used as the
            speaker-embedding size of ``decoder3``.
        attention_heads: Number of heads of every ``MultiHeadedAttention``.
        linear_units: Hidden size of every ``PositionwiseFeedForward``.
        asr_num_blocks: Total number of ASR-branch blocks (first layer
            included).
        spk_num_blocks: Total number of speaker-branch blocks (first layer
            included).
        dropout_rate: Dropout for the layers and feed-forward modules.
        positional_dropout_rate: Forwarded to the base class.
        self_attention_dropout_rate: Dropout of self-attention modules.
        src_attention_dropout_rate: Dropout of source-attention modules.
        input_layer: Forwarded to the base class.
        use_asr_output_layer: Forwarded to the base class.
        use_spk_output_layer: Forwarded to the base class.
        pos_enc_class: Forwarded to the base class.
        normalize_before: Forwarded to the base class and to every layer.
        concat_after: Forwarded to every layer.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size

        def build_decoder_layer(lnum):
            # Shared factory for the generic layers of both branches.  Each
            # call builds fresh sub-modules, in the same order as before
            # (self-attention, source attention, feed-forward).
            return DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            )

        # Modules are created in the order decoder1, decoder2, decoder3,
        # decoder4 so that parameter initialisation order is unchanged.
        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            attention_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, self_attention_dropout_rate
            ),
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder2 = repeat(spk_num_blocks - 1, build_decoder_layer)

        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            attention_dim,
            spker_embedding_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder4 = repeat(asr_num_blocks - 1, build_decoder_layer)
| | | |
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
    """First layer of the speaker decoder branch.

    Applies self-attention, then a source attention whose key is
    ``asr_memory`` and whose value is ``spk_memory``, then a feed-forward
    module.  Besides the layer output it also returns ``z``, the result of
    the self-attention sub-block.

    Args:
        size: Feature dimension of the layer.
        self_attn: Callable invoked as ``self_attn(query, key, value, mask)``.
        src_attn: Callable invoked as ``src_attn(query, key, value, mask)``.
        feed_forward: Callable applied to the attention output.
        dropout_rate: Dropout probability.
        normalize_before: Apply layer norm before (True) or after (False)
            each sub-block.
        concat_after: If true, the sub-block input and the attention output
            are concatenated and linearly projected instead of the attention
            output being added after dropout.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeSpkDecoderFirstLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
        """Compute the layer output.

        Args:
            tgt: Input features; the cache-shape check below implies
                (batch, time, size).
            tgt_mask: Mask for ``tgt`` (may be ``None``).
            asr_memory: Key input of the source attention.
            spk_memory: Value input of the source attention.
            memory_mask: Mask passed to the source attention.
            cache: Output of this layer for the previous time steps, of shape
                (batch, time - 1, size).  When given, only the last position
                is computed and appended to the cache.

        Returns:
            Tuple ``(x, tgt_mask, asr_memory, spk_memory, memory_mask, z)``.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            expected_shape = (tgt.shape[0], tgt.shape[1] - 1, self.size)
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {expected_shape}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        # Self-attention sub-block.
        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        # Snapshot of the self-attention result, returned to the caller.
        z = x

        # Source-attention sub-block: query x, key asr_memory, value spk_memory.
        # NOTE(review): norm1 is reused here rather than a dedicated norm;
        # left unchanged because adding a module would alter the state dict.
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)

        if self.concat_after:
            x_concat = torch.cat((x, skip), dim=-1)
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(skip)
        if not self.normalize_before:
            x = self.norm1(x)

        # Feed-forward sub-block.
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
| | | |
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
    """First layer of the ASR decoder branch.

    Applies a source attention over ``memory``, optionally adds a linear
    projection of ``dn`` to the attention result, and then applies the
    feed-forward module.  This layer has no self-attention.

    Args:
        size: Feature dimension of the layer.
        d_size: Feature dimension of ``dn`` (input size of ``spk_linear``).
        src_attn: Callable invoked as ``src_attn(query, key, value, mask)``.
        feed_forward: Callable applied after the attention sub-block.
        dropout_rate: Dropout probability.
        normalize_before: Apply layer norm before (True) or after (False)
            each sub-block.
        concat_after: If true, the attention input and output are
            concatenated and linearly projected instead of the attention
            output being added after dropout.
    """

    def __init__(
        self,
        size,
        d_size,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeAsrDecoderFirstLayer object."""
        super().__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.spk_linear = nn.Linear(d_size, size, bias=False)
        if self.concat_after:
            # concat_linear1 is not used by forward() in this class; it is
            # still created so the set of parameters stays the same.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
        """Compute the layer output.

        Args:
            tgt: Input features; indexed as (batch, time, feature).
            tgt_mask: Mask for ``tgt``; returned unchanged.
            memory: Key and value input of the source attention.
            memory_mask: Mask passed to the source attention.
            dn: Optional input projected by ``spk_linear`` and added to the
                attention result; skipped when ``None``.
            cache: Output of this layer for previous time steps.  When given,
                only the last position of ``tgt`` is processed and the result
                is appended to the cache along the time axis.

        Returns:
            Tuple ``(x, tgt_mask, memory, memory_mask)``.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            x = tgt
        else:
            # Only the newest position needs to be computed.
            x = tgt[:, -1:, :]
            residual = residual[:, -1:, :]

        # Source-attention sub-block.  In pre-norm mode the input has already
        # passed through norm1 above and is normalised again with norm2.
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)
        residual = x

        # Identity check: "!=" on a tensor is an element-wise operator and is
        # not the right way to test for a missing argument.
        if dn is not None:
            x = x + self.spk_linear(dn)
        if self.normalize_before:
            x = self.norm3(x)

        # Feed-forward sub-block; the residual excludes the dn projection.
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask