from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from funasr.models.transformer.utils.mask import subsequent_mask


class LayerNorm(nn.LayerNorm):
    """LayerNorm computed in float32, cast back to the input dtype."""

    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)

class MultiHeadAttention(nn.Module):

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_pad_mask: bool = False
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            if not is_pad_mask:
                # additive causal mask, cropped to the current context length
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                # boolean padding mask: True marks padded frames
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                min_value = -float("inf")
                qk = qk.masked_fill(mask, min_value)

        qk = qk.float()


        return x

    def init_state(self, x):
        state = {}

        return state

    def final_score(self, state) -> float:
        """Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score

        """
        return 0.0

    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        return logp.squeeze(0)[-1, :], state
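
# Illustration (not part of the original file): subsequent_mask(n), used in
# score() above, is a lower-triangular causal mask -- position i may attend
# only to positions <= i. Combined with an -inf fill as in the attention code
# above, future positions receive exactly zero attention weight.
def _demo_causal_mask(n=3):
    m = subsequent_mask(n).bool()  # (n, n), True on and below the diagonal
    scores = torch.zeros(n, n).masked_fill(~m, -float("inf"))
    w = torch.softmax(scores, dim=-1)
    assert float(w[0, 0]) == 1.0   # the first step attends only to itself
    assert torch.all(w[-1] > 0)    # the last step attends to every position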


class MultiHeadedAttentionSANMDecoder(nn.Module):
    """Multi-Head Attention layer with an FSMN memory block (SANM decoder)."""

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        # depth-wise 1-D convolution realising the FSMN memory block
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # asymmetric padding so the convolution preserves the sequence length
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None, **kwargs):
        """FSMN memory-block forward.

        :param inputs: input tensor (#batch, time1, size)
        :param mask: mask tensor (#batch, 1, time)
        :param cache: cached FSMN states from the previous decoding step
        """
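
# Illustration (not part of the original file): a hypothetical check that the
# asymmetric ConstantPad1d above adds kernel_size - 1 frames in total, so the
# depth-wise Conv1d (padding=0) returns the original sequence length.
def _demo_fsmn_padding(n_feat=4, kernel_size=11, sanm_shfit=0):
    left = (kernel_size - 1) // 2 + sanm_shfit
    right = kernel_size - 1 - left
    pad_fn = nn.ConstantPad1d((left, right), 0.0)
    conv = nn.Conv1d(n_feat, n_feat, kernel_size, groups=n_feat, bias=False)
    x = torch.randn(2, n_feat, 50)  # (batch, channels, time)
    y = conv(pad_fn(x))
    assert y.shape == x.shape  # sequence length preserved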


class ResidualAttentionBlock(nn.Module):

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        cache = kwargs.get("cache", {})
        layer = kwargs.get("layer", 0)
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        # original Whisper self-attention path, superseded by the FSMN block below:
        # x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]

        fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
        # if fsmn_cache is not None:
        #     x = x[:, -1:]
        att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
        # if len(cache) > 1:
        #     cache[layer]["fsmn_cache"] = fsmn_cache
        #     x = x[:, -1:]
        x = x + att_res
        if self.cross_attn:
            x = (
                x
                + self.cross_attn(
                    self.cross_attn_ln(x),
                    xa,
                    mask=memory_mask,
                    is_pad_mask=is_pad_memory_mask,
                )[0]
            )
        x = x + self.mlp(self.mlp_ln(x))
        return x


class TextDecoder(nn.Module):

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, **kwargs):
        ys_in_lens = kwargs.get("ys_in_lens", None)

        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        tgt, memory = x, xa
        tgt[tgt == -1] = 0  # map padding ids (-1) onto a valid embedding index
        tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)]
        # kv-cache variant: tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
        # tgt = self.dropout(tgt)

        x = tgt.to(memory.dtype)
        memory_mask = kwargs.get("memory_mask", None)  # padding mask over the encoder memory

        for layer, block in enumerate(self.blocks):
            x = block(
                x,
                memory,
                memory_mask=memory_mask,
                is_pad_mask=False,
                is_pad_memory_mask=True,
                cache=kwargs.get("cache", None),
                layer=layer,
            )

        x = self.ln(x)
        # output projection tied to the token-embedding weights
        x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()

        return x
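
    # Illustration (not part of the original file): the projection above ties
    # the output layer to the token embedding. Multiplying by the transposed
    # embedding matrix equals F.linear with no bias; the names below are
    # hypothetical demo values.
    @staticmethod
    def _demo_tied_output_projection(vocab=10, dim=8):
        emb = nn.Embedding(vocab, dim)
        x = torch.randn(2, 3, dim)
        a = x @ torch.transpose(emb.weight, 0, 1)  # as written in forward()
        b = F.linear(x, emb.weight)                # equivalent formulation
        assert torch.allclose(a, b)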

    def init_state(self, x):
        # one cache slot per decoder block: FSMN memory plus the cached
        # cross-attention key/value projections of the encoder output
        state = {}
        for layer, block in enumerate(self.blocks):
            state[layer] = {
                "fsmn_cache": None,
                "memory_key": None,
                "memory_value": None,
            }

        return state
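
    # Illustration (not part of the original file): the cache contract that
    # init_state() establishes -- every layer slot carries exactly these keys.
    @staticmethod
    def _demo_state_layout(state):
        for layer, slot in state.items():
            assert set(slot) == {"fsmn_cache", "memory_key", "memory_value"}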

    def final_score(self, state) -> float:
        """Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score

        """
        return 0.0

    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=None)
        logp = torch.log_softmax(logp, dim=-1)
        return logp.squeeze(0)[-1, :], state
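

# Illustration (not part of the original file): a minimal greedy decoding loop
# over the scorer interface above. `decoder`, `enc_out`, and the token ids are
# hypothetical stand-ins; real inference would typically use beam search.
def _demo_greedy_decode(decoder, enc_out, sos_id=1, eos_id=2, max_len=50):
    state = decoder.init_state(enc_out)
    ys = torch.tensor([sos_id], dtype=torch.long)
    for _ in range(max_len):
        logp, state = decoder.score(ys, state, enc_out)
        nxt = int(logp.argmax())
        ys = torch.cat([ys, torch.tensor([nxt], dtype=torch.long)])
        if nxt == eos_id:
            break
    return ys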