import copy

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.transformer.utils.nets_utils import make_pad_mask


def sense_voice_decode_forward(
    self,
    x: torch.Tensor,
    xa: torch.Tensor,
    kv_cache: Optional[dict] = None,
    **kwargs,
):
"""Forward decoder.
|
|
Args:
|
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
|
hlens: (batch)
|
ys_in_pad:
|
input token ids, int64 (batch, maxlen_out)
|
if input_layer == "embed"
|
input tensor (batch, maxlen_out, #mels) in the other cases
|
ys_in_lens: (batch)
|
Returns:
|
(tuple): tuple containing:
|
|
x: decoded token score before softmax (batch, maxlen_out, token)
|
if use_output_layer is True,
|
olens: (batch, )
|
"""
    use_padmask = self.use_padmask
    hlens = kwargs.get("hlens", None)
    ys_in_lens = kwargs.get("ys_in_lens", None)  # currently unused

    # In incremental decoding, offset the positional embedding by the
    # number of tokens already stored in the key/value cache.
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    tgt, memory = x, xa
    # Remap ignore-id (-1) padding to token id 0 (in place) so the
    # embedding lookup is valid.
    tgt[tgt == -1] = 0
    tgt = (
        self.token_embedding(tgt)
        + self.positional_embedding[offset : offset + tgt.size(1)]
    )
    # tgt = self.dropout(tgt)

    x = tgt.to(memory.dtype)

    if use_padmask and hlens is not None:
        # Mask padded encoder frames so cross-attention ignores them.
        memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
    else:
        memory_mask = None

    for block in self.blocks:
        x = block(
            x,
            memory,
            mask=self.mask,
            memory_mask=memory_mask,
            is_pad_mask=False,
            is_pad_memory_mask=True,
        )

    x = self.ln(x)
    # Tied output projection: reuse the token-embedding matrix as the
    # pre-softmax classifier weights.
    x = (
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()

    return x
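

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original file). The explicit
# `self` parameter suggests this function is meant to be patched onto a
# Whisper-style decoder module exposing `token_embedding`,
# `positional_embedding`, `blocks`, `ln`, `mask`, and `use_padmask`.
def _bind_sense_voice_decode_forward(decoder: nn.Module) -> None:
    """Bind sense_voice_decode_forward as `decoder.forward` (sketch only)."""
    import types

    decoder.forward = types.MethodType(sense_voice_decode_forward, decoder)


# With hypothetical `decoder`, `ys_in_pad`, `encoder_out`, `encoder_out_lens`:
#     _bind_sense_voice_decode_forward(decoder)
#     logits = decoder(ys_in_pad, encoder_out, hlens=encoder_out_lens)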