import torch
|
from torch import nn
|
|
from funasr.modules.layer_norm import LayerNorm
|
|
|
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
    """First layer of the speaker-side decoder in speaker-attributed ASR.

    The layer stacks three residual sub-layers:

    1. self-attention over the decoder input ``tgt``;
    2. a source attention that receives both the ASR memory and the speaker
       memory, called as ``src_attn(x, asr_memory, spk_memory, memory_mask)``;
    3. a position-wise feed-forward network.

    Besides the layer output, ``forward`` also returns ``z``, the output of
    the self-attention sub-layer.

    Args:
        size: Model dimension. Used for the layer norms, the optional concat
            projections and the incremental-decoding cache shape check.
        self_attn: Self-attention module, called as ``self_attn(q, k, v, mask)``.
        src_attn: Source-attention module, called as
            ``src_attn(x, asr_memory, spk_memory, memory_mask)``.
        feed_forward: Position-wise feed-forward module.
        dropout_rate: Dropout probability applied to sub-layer outputs.
        normalize_before: If True, layer norm is applied to each sub-layer's
            input (pre-norm); otherwise it is applied after the residual
            addition (post-norm).
        concat_after: If True, each attention sub-layer contributes
            ``concat_linearN(cat(input, output))`` to the residual instead of
            ``dropout(output)``.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeSpkDecoderFirstLayer."""
        super().__init__()
        self.size = size

        # Sub-layer modules supplied by the caller.
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward

        # norm1 serves both attention sub-layers, norm2 the feed-forward one.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)

        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if concat_after:
            # Projections from [input; attention output] back to ``size``.
            # Created in this order so registration order is unchanged.
            self.concat_linear1 = nn.Linear(2 * size, size)
            self.concat_linear2 = nn.Linear(2 * size, size)

    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
        """Apply the layer.

        Args:
            tgt: Decoder input. Sliced as ``tgt[:, -1:, :]`` and checked
                against ``(batch, time - 1, size)`` when ``cache`` is given,
                so presumably ``(batch, time, size)``.
            tgt_mask: Self-attention mask for ``tgt``, or None.
            asr_memory: ASR-side memory forwarded to ``src_attn``.
            spk_memory: Speaker-side memory forwarded to ``src_attn``.
            memory_mask: Memory mask forwarded to ``src_attn``.
            cache: Previously computed outputs of this layer for all but the
                newest position, or None. When given, only the newest
                position is computed and the result is appended to
                ``cache`` along dim 1.

        Returns:
            ``(x, tgt_mask, asr_memory, spk_memory, memory_mask, z)``: the
            layer output, the unchanged pass-through arguments, and ``z``,
            the self-attention sub-layer output (only the newest position
            when ``cache`` is given).

        Raises:
            AssertionError: if ``cache`` does not have shape
                ``(batch, time - 1, size)``.
        """
        pre_norm = self.normalize_before
        shortcut = tgt
        if pre_norm:
            tgt = self.norm1(tgt)

        # --- choose the query positions -----------------------------------
        if cache is not None:
            # Incremental decoding: recompute only the newest position.
            expected_shape = (tgt.shape[0], tgt.shape[1] - 1, self.size)
            assert cache.shape == expected_shape, f"{cache.shape} == {expected_shape}"
            query = tgt[:, -1:, :]
            shortcut = shortcut[:, -1:, :]
            query_mask = None if tgt_mask is None else tgt_mask[:, -1:, :]
        else:
            query = tgt
            query_mask = tgt_mask

        # --- sub-layer 1: self-attention ----------------------------------
        attn_out = self.self_attn(query, tgt, tgt, query_mask)
        if self.concat_after:
            merged = self.concat_linear1(torch.cat((query, attn_out), dim=-1))
        else:
            merged = self.dropout(attn_out)
        x = shortcut + merged
        if not pre_norm:
            x = self.norm1(x)
        z = x

        # --- sub-layer 2: source attention over ASR + speaker memory ------
        shortcut = x
        # NOTE(review): norm1 is reused here although norm2 exists. Adding a
        # dedicated norm would change the state_dict keys and the outputs of
        # already-trained models, so it is left as-is.
        src_in = self.norm1(x) if pre_norm else x
        attn_out = self.src_attn(src_in, asr_memory, spk_memory, memory_mask)
        if self.concat_after:
            merged = self.concat_linear2(torch.cat((src_in, attn_out), dim=-1))
        else:
            merged = self.dropout(attn_out)
        x = shortcut + merged
        if not pre_norm:
            x = self.norm1(x)

        # --- sub-layer 3: feed-forward ------------------------------------
        shortcut = x
        ff_in = self.norm2(x) if pre_norm else x
        x = shortcut + self.dropout(self.feed_forward(ff_in))
        if not pre_norm:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
|
|
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
    """First layer of the ASR-side decoder in speaker-attributed ASR.

    Unlike a standard decoder layer there is no self-attention sub-layer:
    the layer applies source attention over ``memory``, optionally adds a
    projected speaker representation ``dn``, then applies a position-wise
    feed-forward network.

    Args:
        size: Model dimension.
        d_size: Feature dimension of the speaker input ``dn``; it is
            projected to ``size`` by the bias-free ``spk_linear``.
        src_attn: Source-attention module, called as
            ``src_attn(x, memory, memory, memory_mask)``.
        feed_forward: Position-wise feed-forward module.
        dropout_rate: Dropout probability applied to sub-layer outputs.
        normalize_before: If True, layer norm is applied to each sub-layer's
            input (pre-norm); otherwise after the residual addition
            (post-norm).
        concat_after: If True, the source-attention sub-layer contributes
            ``concat_linear2(cat(input, output))`` to the residual instead of
            ``dropout(output)``.
    """

    def __init__(
        self,
        size,
        d_size,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeAsrDecoderFirstLayer."""
        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        # Projects the speaker representation ``dn`` into the model dimension.
        self.spk_linear = nn.Linear(d_size, size, bias=False)
        if self.concat_after:
            # NOTE(review): concat_linear1 is never used by forward() (this
            # layer has no self-attention sub-layer). It is still registered,
            # presumably so parameter names match existing checkpoints;
            # removing it would change the module's state_dict keys.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
        """Apply the layer.

        Args:
            tgt: Decoder input; sliced as ``tgt[:, -1:, :]`` when ``cache`` is
                given, so presumably ``(batch, time, size)`` — TODO confirm.
            tgt_mask: Target mask. Not used for computation (there is no
                self-attention here); returned unchanged.
            memory: Encoder memory used as key and value of ``src_attn``.
            memory_mask: Memory mask forwarded to ``src_attn``.
            dn: Speaker representation or None. When given, ``spk_linear(dn)``
                is added to the input of the feed-forward sub-layer; it must
                broadcast against that input (presumably
                ``(batch, time, d_size)`` — TODO confirm).
            cache: Previously computed outputs of this layer for all but the
                newest position, or None. When given, only the newest
                position is computed and appended to ``cache`` along dim 1.

        Returns:
            ``(x, tgt_mask, memory, memory_mask)``: the layer output and the
            unchanged pass-through arguments.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
        else:
            # Incremental decoding: only the newest position is computed; the
            # earlier positions come from ``cache`` and are prepended below.
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]

        # --- sub-layer 1: source attention over ``memory`` ----------------
        x = tgt_q
        if self.normalize_before:
            # NOTE(review): in pre-norm mode the query has already passed
            # through norm1 above, so it is normalized twice (norm1, then
            # norm2). Kept because changing it would alter the outputs of
            # already-trained models.
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)

        # --- sub-layer 2: speaker injection + feed-forward ----------------
        # The residual is taken before the speaker term is added, so ``dn``
        # only reaches the output through the feed-forward branch.
        residual = x
        # ``is not None`` (not ``!=``): ``!=`` on array-like ``dn`` is
        # elementwise, and the resulting array has no single truth value.
        if dn is not None:
            x = x + self.spk_linear(dn)
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
|