仁迷
2023-03-13 3762d21300e1f3fa3e0cb1e67545227e6dcec3de
funasr/modules/attention.py
@@ -347,15 +347,17 @@
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        return x * mask
        if mask is not None:
            x = x * mask
        return x
    def forward_qkv(self, x):
        """Transform query, key and value.
@@ -505,7 +507,7 @@
            # print("in fsmn, cache is None, x", x.size())
            x = self.pad_fn(x)
            if not self.training and t <= 1:
            if not self.training:
                cache = x
        else:
            # print("in fsmn, cache is not None, x", x.size())
@@ -513,7 +515,7 @@
            # if t < self.kernel_size:
            #     x = self.pad_fn(x)
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -self.kernel_size:]
            x = x[:, :, -(self.kernel_size+t-1):]
            # print("in fsmn, cache is not None, x_cat", x.size())
            cache = x
        x = self.fsmn_block(x)