Zhihao Du
2023-03-16 38de2af5bf9976d2f14f087d9a0d31991daf6783
funasr/export/models/modules/multihead_att.py
@@ -64,6 +64,21 @@
        return self.linear_out(context_layer)  # (batch, time1, d_model)
def preprocess_for_attn(x, mask, cache, pad_fn):
    x = x * mask
    x = x.transpose(1, 2)
    if cache is None:
        x = pad_fn(x)
    else:
        x = torch.cat((cache[:, :, 1:], x), dim=2)
        cache = x
    return x, cache
import torch.fx
torch.fx.wrap('preprocess_for_attn')
class MultiHeadedAttentionSANMDecoder(nn.Module):
    def __init__(self, model):
        super().__init__()
@@ -73,16 +88,7 @@
        self.attn = None
    def forward(self, inputs, mask, cache=None):
        # b, t, d = inputs.size()
        # mask = torch.reshape(mask, (b, -1, 1))
        inputs = inputs * mask
        x = inputs.transpose(1, 2)
        if cache is None:
            x = self.pad_fn(x)
        else:
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            cache = x
        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)