From d0cd484fdc21c06b8bc892bb2ab1c2a25fb1da8a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 31 Mar 2023 15:05:37 +0800
Subject: [PATCH] export
---
funasr/train/abs_model.py | 56 ++++++++++++++++++++++++++++++++++++++++++--------------
1 file changed, 42 insertions(+), 14 deletions(-)
diff --git a/funasr/punctuation/espnet_model.py b/funasr/train/abs_model.py
similarity index 85%
rename from funasr/punctuation/espnet_model.py
rename to funasr/train/abs_model.py
index 7266b38..8bfba45 100644
--- a/funasr/punctuation/espnet_model.py
+++ b/funasr/train/abs_model.py
@@ -1,3 +1,9 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Tuple
+
+import torch
+
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -7,13 +13,34 @@
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
-from funasr.punctuation.abs_model import AbsPunctuation
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-class ESPnetPunctuationModel(AbsESPnetModel):
+class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
+ """The abstract class
+
+ To share the loss calculation way among different models,
+ We use the delegate pattern here:
+ The instance of this class should be passed to "LanguageModel"
+
+ This "model" is one of mediator objects for "Task" class.
+
+ """
+
+ @abstractmethod
+ def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def with_vad(self) -> bool:
+ raise NotImplementedError
+
+
+class PunctuationModel(AbsESPnetModel):
+
def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
assert check_argument_types()
super().__init__()
@@ -21,12 +48,12 @@
self.punc_weight = torch.Tensor(punc_weight)
self.sos = 1
self.eos = 2
-
+
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
self.ignore_id = ignore_id
- #if self.punc_model.with_vad():
+ # if self.punc_model.with_vad():
# print("This is a vad puncuation model.")
-
+
def nll(
self,
text: torch.Tensor,
@@ -54,7 +81,7 @@
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
-
+
if self.punc_model.with_vad():
# Should be VadRealtimeTransformer
assert vad_indexes is not None
@@ -62,7 +89,7 @@
else:
# Should be TargetDelayTransformer,
y, _ = self.punc_model(text, text_lengths)
-
+
# Calc negative log likelihood
# nll: (BxL,)
if self.training == False:
@@ -75,7 +102,8 @@
return nll, text_lengths
else:
self.punc_weight = self.punc_weight.to(punc.device)
- nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
+ nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+ ignore_index=self.ignore_id)
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -87,7 +115,7 @@
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, text_lengths
-
+
def batchify_nll(self,
text: torch.Tensor,
punc: torch.Tensor,
@@ -113,7 +141,7 @@
nlls = []
x_lengths = []
max_length = text_lengths.max()
-
+
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
@@ -132,7 +160,7 @@
assert nll.size(0) == total_num
assert x_lengths.size(0) == total_num
return nll, x_lengths
-
+
def forward(
self,
text: torch.Tensor,
@@ -146,15 +174,15 @@
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
-
+
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
-
+
def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
return {}
-
+
def inference(self,
text: torch.Tensor,
text_lengths: torch.Tensor,
--
Gitblit v1.9.1