From 1d97d628f2f19674fa50495e984db8185604ca8e Mon Sep 17 00:00:00 2001
From: lzr265946 <lzr265946@alibaba-inc.com>
Date: Fri, 3 Feb 2023 14:11:22 +0800
Subject: [PATCH] Merge branch 'main' into dev
---
funasr/punctuation/target_delay_transformer.py | 30 ++++++++++--------------------
 1 file changed, 10 insertions(+), 20 deletions(-)
diff --git a/funasr/punctuation/target_delay_transformer.py b/funasr/punctuation/target_delay_transformer.py
index 66025cb..10cc5a8 100644
--- a/funasr/punctuation/target_delay_transformer.py
+++ b/funasr/punctuation/target_delay_transformer.py
@@ -14,6 +14,7 @@
class TargetDelayTransformer(AbsPunctuation):
+
def __init__(
self,
vocab_size: int,
@@ -28,7 +29,7 @@
):
super().__init__()
if pos_enc == "sinusoidal":
-# pos_enc_class = PositionalEncoding
+ # pos_enc_class = PositionalEncoding
pos_enc_class = SinusoidalPositionEncoder
elif pos_enc is None:
@@ -47,16 +48,16 @@
num_blocks=layer,
dropout_rate=dropout_rate,
input_layer="pe",
- # pos_enc_class=pos_enc_class,
+ # pos_enc_class=pos_enc_class,
padding_idx=0,
)
self.decoder = nn.Linear(att_unit, punc_size)
+
# def _target_mask(self, ys_in_pad):
# ys_mask = ys_in_pad != 0
# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
# return ys_mask.unsqueeze(-2) & m
-
def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
@@ -67,14 +68,12 @@
"""
x = self.embed(input)
- # mask = self._target_mask(input)
+ # mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y, None
- def score(
- self, y: torch.Tensor, state: Any, x: torch.Tensor
- ) -> Tuple[torch.Tensor, Any]:
+ def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token.
Args:
@@ -89,16 +88,12 @@
"""
y = y.unsqueeze(0)
- h, _, cache = self.encoder.forward_one_step(
- self.embed(y), self._target_mask(y), cache=state
- )
+ h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1).squeeze(0)
return logp, cache
- def batch_score(
- self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
- ) -> Tuple[torch.Tensor, List[Any]]:
+ def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
@@ -120,15 +115,10 @@
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
- batch_state = [
- torch.stack([states[b][i] for b in range(n_batch)])
- for i in range(n_layers)
- ]
+ batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
# batch decoding
- h, _, states = self.encoder.forward_one_step(
- self.embed(ys), self._target_mask(ys), cache=batch_state
- )
+ h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1)
--
Gitblit v1.9.1