| | |
| | | |
| | | """ |
| | | q_h, k_h, v_h, v = self.forward_qkv(x) |
| | | if chunk_size is not None and look_back > 0: |
| | | if chunk_size is not None and look_back > 0 or look_back == -1: |
| | | if cache is not None: |
| | | k_h_stride = k_h[:, :, :-(chunk_size[2]), :] |
| | | v_h_stride = v_h[:, :, :-(chunk_size[2]), :] |
| | | k_h = torch.cat((cache["k"], k_h), dim=2) |
| | | v_h = torch.cat((cache["v"], v_h), dim=2) |
| | | cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :] |
| | | cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :] |
| | | |
| | | cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2) |
| | | cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2) |
| | | if look_back != -1: |
| | | cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :] |
| | | cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :] |
| | | else: |
| | | cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :], |
| | | "v": v_h[:, :, -(look_back * chunk_size[1]):, :]} |
| | | cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :], |
| | | "v": v_h[:, :, :-(chunk_size[2]), :]} |
| | | cache = cache_tmp |
| | | fsmn_memory = self.forward_fsmn(v, None) |
| | | q_h = q_h * self.d_k ** (-0.5) |