From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/sanm/attention.py |  380 +++++++++++++++++++++++++++++++++++++++++++++--------
 1 files changed, 321 insertions(+), 59 deletions(-)

diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 10f0a3b..c7e8a8e 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -17,6 +17,25 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 import funasr.models.lora.layers as lora
 
+
+def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
+    x = x * mask
+    x = x.transpose(1, 2)
+    if cache is None:
+        x = pad_fn(x)
+    else:
+        x = torch.cat((cache, x), dim=2)
+        cache = x[:, :, -(kernel_size - 1) :]
+    return x, cache
+
+
+torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
+if torch_version >= (1, 8):
+    import torch.fx
+
+    torch.fx.wrap("preprocess_for_attn")
+
+
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
 
@@ -81,9 +100,9 @@
         n_batch = value.size(0)
         if mask is not None:
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
+            min_value = -float(
+                "inf"
+            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
             self.attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
@@ -118,10 +137,6 @@
         return self.forward_attention(v, scores, mask)
 
 
-
-
-
-
 class MultiHeadedAttentionSANM(nn.Module):
     """Multi-Head Attention layer.
 
@@ -132,7 +147,19 @@
 
     """
 
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
+    def __init__(
+        self,
+        n_head,
+        in_feat,
+        n_feat,
+        dropout_rate,
+        kernel_size,
+        sanm_shfit=0,
+        lora_list=None,
+        lora_rank=8,
+        lora_alpha=16,
+        lora_dropout=0.1,
+    ):
         """Construct an MultiHeadedAttention object."""
         super().__init__()
         assert n_feat % n_head == 0
@@ -144,21 +171,32 @@
         # self.linear_v = nn.Linear(n_feat, n_feat)
         if lora_list is not None:
             if "o" in lora_list:
-                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+                self.linear_out = lora.Linear(
+                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
+                )
             else:
                 self.linear_out = nn.Linear(n_feat, n_feat)
             lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
             if lora_qkv_list == [False, False, False]:
                 self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
             else:
-                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+                self.linear_q_k_v = lora.MergedLinear(
+                    in_feat,
+                    n_feat * 3,
+                    r=lora_rank,
+                    lora_alpha=lora_alpha,
+                    lora_dropout=lora_dropout,
+                    enable_lora=lora_qkv_list,
+                )
         else:
             self.linear_out = nn.Linear(n_feat, n_feat)
             self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)
 
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        self.fsmn_block = nn.Conv1d(
+            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
+        )
         # padding
         left_padding = (kernel_size - 1) // 2
         if sanm_shfit > 0:
@@ -201,9 +239,15 @@
         b, t, d = x.size()
         q_k_v = self.linear_q_k_v(x)
         q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
 
         return q_h, k_h, v_h, v
 
@@ -227,9 +271,9 @@
 
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
 
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
+            min_value = -float(
+                "inf"
+            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
             self.attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
@@ -283,19 +327,21 @@
         q_h, k_h, v_h, v = self.forward_qkv(x)
         if chunk_size is not None and look_back > 0 or look_back == -1:
             if cache is not None:
-                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
-                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
+                k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
+                v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
                 k_h = torch.cat((cache["k"], k_h), dim=2)
                 v_h = torch.cat((cache["v"], v_h), dim=2)
 
                 cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                 cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                 if look_back != -1:
-                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
-                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
+                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
+                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
             else:
-                cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
-                             "v": v_h[:, :, :-(chunk_size[2]), :]}
+                cache_tmp = {
+                    "k": k_h[:, :, : -(chunk_size[2]), :],
+                    "v": v_h[:, :, : -(chunk_size[2]), :],
+                }
                 cache = cache_tmp
         fsmn_memory = self.forward_fsmn(v, None)
         q_h = q_h * self.d_k ** (-0.5)
@@ -303,6 +349,123 @@
         att_outs = self.forward_attention(v_h, scores, None)
         return att_outs + fsmn_memory, cache
 
+
+class MultiHeadedAttentionSANMExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.d_k = model.d_k
+        self.h = model.h
+        self.linear_out = model.linear_out
+        self.linear_q_k_v = model.linear_q_k_v
+        self.fsmn_block = model.fsmn_block
+        self.pad_fn = model.pad_fn
+
+        self.attn = None
+        self.all_head_size = self.h * self.d_k
+
+    def forward(self, x, mask):
+        mask_3d_btd, mask_4d_bhlt = mask
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
+        return att_outs + fsmn_memory
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward_qkv(self, x):
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = self.transpose_for_scores(q)
+        k_h = self.transpose_for_scores(k)
+        v_h = self.transpose_for_scores(v)
+        return q_h, k_h, v_h, v
+
+    def forward_fsmn(self, inputs, mask):
+        # b, t, d = inputs.size()
+        # mask = torch.reshape(mask, (b, -1, 1))
+        inputs = inputs * mask
+        x = inputs.transpose(1, 2)
+        x = self.pad_fn(x)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        x = x + inputs
+        x = x * mask
+        return x
+
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
+
+class MultiHeadedAttentionSANMExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.d_k = model.d_k
+        self.h = model.h
+        self.linear_out = model.linear_out
+        self.linear_q_k_v = model.linear_q_k_v
+        self.fsmn_block = model.fsmn_block
+        self.pad_fn = model.pad_fn
+
+        self.attn = None
+        self.all_head_size = self.h * self.d_k
+
+    def forward(self, x, mask):
+        mask_3d_btd, mask_4d_bhlt = mask
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
+        return att_outs + fsmn_memory
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward_qkv(self, x):
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = self.transpose_for_scores(q)
+        k_h = self.transpose_for_scores(k)
+        v_h = self.transpose_for_scores(v)
+        return q_h, k_h, v_h, v
+
+    def forward_fsmn(self, inputs, mask):
+        # b, t, d = inputs.size()
+        # mask = torch.reshape(mask, (b, -1, 1))
+        inputs = inputs * mask
+        x = inputs.transpose(1, 2)
+        x = self.pad_fn(x)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        x = x + inputs
+        x = x * mask
+        return x
+
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer)  # (batch, time1, d_model)
 
 
 class MultiHeadedAttentionSANMDecoder(nn.Module):
@@ -317,12 +480,13 @@
 
     def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
         """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionSANMDecoder, self).__init__()
+        super().__init__()
 
         self.dropout = nn.Dropout(p=dropout_rate)
 
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
-                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        self.fsmn_block = nn.Conv1d(
+            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
+        )
         # padding
         # padding
         left_padding = (kernel_size - 1) // 2
@@ -333,17 +497,17 @@
         self.kernel_size = kernel_size
 
     def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
-        '''
+        """
         :param x: (#batch, time1, size).
         :param mask: Mask tensor (#batch, 1, time)
         :return:
-        '''
+        """
         # print("in fsmn, inputs", inputs.size())
         b, t, d = inputs.size()
         # logging.info(
         #     "mask: {}".format(mask.size()))
         if mask is not None:
-            mask = torch.reshape(mask, (b ,-1, 1))
+            mask = torch.reshape(mask, (b, -1, 1))
             # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
             if mask_shfit_chunk is not None:
                 # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
@@ -367,7 +531,7 @@
             # if t < self.kernel_size:
             #     x = self.pad_fn(x)
             x = torch.cat((cache[:, :, 1:], x), dim=2)
-            x = x[:, :, -(self.kernel_size+t-1):]
+            x = x[:, :, -(self.kernel_size + t - 1) :]
             # print("in fsmn, cache is not None, x_cat", x.size())
             cache = x
         x = self.fsmn_block(x)
@@ -382,6 +546,25 @@
             x = x * mask
         return x, cache
 
+
+class MultiHeadedAttentionSANMDecoderExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.fsmn_block = model.fsmn_block
+        self.pad_fn = model.pad_fn
+        self.kernel_size = model.kernel_size
+        self.attn = None
+
+    def forward(self, inputs, mask, cache=None):
+        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+
+        x = x + inputs
+        x = x * mask
+        return x, cache
+
+
 class MultiHeadedAttentionCrossAtt(nn.Module):
     """Multi-Head Attention layer.
 
@@ -392,31 +575,55 @@
 
     """
 
-    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
+    def __init__(
+        self,
+        n_head,
+        n_feat,
+        dropout_rate,
+        lora_list=None,
+        lora_rank=8,
+        lora_alpha=16,
+        lora_dropout=0.1,
+        encoder_output_size=None,
+    ):
         """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionCrossAtt, self).__init__()
+        super().__init__()
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
         if lora_list is not None:
             if "q" in lora_list:
-                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+                self.linear_q = lora.Linear(
+                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
+                )
             else:
                 self.linear_q = nn.Linear(n_feat, n_feat)
             lora_kv_list = ["k" in lora_list, "v" in lora_list]
             if lora_kv_list == [False, False]:
-                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+                self.linear_k_v = nn.Linear(
+                    n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2
+                )
             else:
-                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
-                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+                self.linear_k_v = lora.MergedLinear(
+                    n_feat if encoder_output_size is None else encoder_output_size,
+                    n_feat * 2,
+                    r=lora_rank,
+                    lora_alpha=lora_alpha,
+                    lora_dropout=lora_dropout,
+                    enable_lora=lora_kv_list,
+                )
             if "o" in lora_list:
-                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+                self.linear_out = lora.Linear(
+                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
+                )
             else:
                 self.linear_out = nn.Linear(n_feat, n_feat)
         else:
             self.linear_q = nn.Linear(n_feat, n_feat)
-            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            self.linear_k_v = nn.Linear(
+                n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2
+            )
             self.linear_out = nn.Linear(n_feat, n_feat)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)
@@ -439,13 +646,18 @@
         # print("in forward_qkv, x", x.size())
         b = x.size(0)
         q = self.linear_q(x)
-        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time1, d_k)
+        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
 
         k_v = self.linear_k_v(memory)
-        k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
-        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
-
+        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
+        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
 
         return q_h, k_h, v_h
 
@@ -465,9 +677,9 @@
         n_batch = value.size(0)
         if mask is not None:
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
+            min_value = -float(
+                "inf"
+            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             # logging.info(
             #     "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
             scores = scores.masked_fill(mask, min_value)
@@ -523,15 +735,62 @@
             if cache is not None:
                 k_h = torch.cat((cache["k"], k_h), dim=2)
                 v_h = torch.cat((cache["v"], v_h), dim=2)
-                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
-                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
+                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]) :, :]
+                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]) :, :]
             else:
-                cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
-                             "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
+                cache_tmp = {
+                    "k": k_h[:, :, -(look_back * chunk_size[1]) :, :],
+                    "v": v_h[:, :, -(look_back * chunk_size[1]) :, :],
+                }
                 cache = cache_tmp
         q_h = q_h * self.d_k ** (-0.5)
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         return self.forward_attention(v_h, scores, None), cache
+
+
+class MultiHeadedAttentionCrossAttExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.d_k = model.d_k
+        self.h = model.h
+        self.linear_q = model.linear_q
+        self.linear_k_v = model.linear_k_v
+        self.linear_out = model.linear_out
+        self.attn = None
+        self.all_head_size = self.h * self.d_k
+
+    def forward(self, x, memory, memory_mask, ret_attn=False):
+        q, k, v = self.forward_qkv(x, memory)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, memory_mask, ret_attn)
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward_qkv(self, x, memory):
+        q = self.linear_q(x)
+
+        k_v = self.linear_k_v(memory)
+        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
+        q = self.transpose_for_scores(q)
+        k = self.transpose_for_scores(k)
+        v = self.transpose_for_scores(v)
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask, ret_attn):
+        scores = scores + mask.to(scores.device)
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        if ret_attn:
+            return self.linear_out(context_layer), self.attn
+        return self.linear_out(context_layer)  # (batch, time1, d_model)
 
 
 class MultiHeadSelfAttention(nn.Module):
@@ -573,9 +832,15 @@
         b, t, d = x.size()
         q_k_v = self.linear_q_k_v(x)
         q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
 
         return q_h, k_h, v_h, v
 
@@ -599,9 +864,9 @@
 
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
 
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
+            min_value = -float(
+                "inf"
+            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
             self.attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
@@ -636,6 +901,3 @@
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
         return att_outs
-
-
-

--
Gitblit v1.9.1