liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/branchformer/fastformer.py
@@ -81,14 +81,11 @@
        # (batch, n_heads, time)
        query_for_score = (
            self.query_att(mixed_query_layer).transpose(1, 2)
            / self.attention_head_size**0.5
            self.query_att(mixed_query_layer).transpose(1, 2) / self.attention_head_size**0.5
        )
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
                ).min
                numpy.finfo(torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype).min
            )
            query_for_score = query_for_score.masked_fill(mask, min_value)
            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
@@ -108,9 +105,7 @@
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)
        mixed_query_key_layer = (
            mixed_key_layer * pooled_query_repeat
        )  # (batch, time, size)
        mixed_query_key_layer = mixed_key_layer * pooled_query_repeat  # (batch, time, size)
        # (batch, n_heads, time)
        query_key_score = (
@@ -118,14 +113,10 @@
        ).transpose(1, 2)
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
                ).min
                numpy.finfo(torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype).min
            )
            query_key_score = query_key_score.masked_fill(mask, min_value)
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
                mask, 0.0
            )
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(mask, 0.0)
        else:
            query_key_weight = torch.softmax(query_key_score, dim=-1)
@@ -133,9 +124,7 @@
        key_layer = self.transpose_for_scores(
            mixed_query_key_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_key = torch.matmul(
            query_key_weight, key_layer
        )  # (batch, n_heads, 1, attn_dim)
        pooled_key = torch.matmul(query_key_weight, key_layer)  # (batch, n_heads, 1, attn_dim)
        pooled_key = self.dropout(pooled_key)
        # NOTE: value = query, due to param sharing
@@ -143,11 +132,8 @@
            1, 2
        )  # (batch, time, n_heads, attn_dim)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
            weighted_value.shape[:-2] + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = (
            self.dropout(self.transform(weighted_value)) + mixed_query_layer
        )
        weighted_value = self.dropout(self.transform(weighted_value)) + mixed_query_layer
        return weighted_value