From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] train: update Paraformer input handling, stats, and add export hook
---
funasr/models/paraformer/model.py | 39 ++++++++++++++++++++++++++-------------
 1 file changed, 26 insertions(+), 13 deletions(-)
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index f92441d..bd85df0 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -13,13 +13,14 @@
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import to_device
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@@ -33,7 +34,6 @@
def __init__(
self,
- # token_list: Union[Tuple[str, ...], List[str]],
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
@@ -155,8 +155,8 @@
self.predictor_bias = predictor_bias
self.sampling_ratio = sampling_ratio
self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
- # self.step_cur = 0
- #
+
+
self.share_embedding = share_embedding
if self.share_embedding:
self.decoder.embed = None
@@ -164,6 +164,7 @@
self.use_1st_decoder_loss = use_1st_decoder_loss
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
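+        # optional error calculator (e.g. for CER/WER); left unset by default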
+ self.error_calculator = None
def forward(
self,
@@ -230,6 +231,7 @@
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
stats["loss"] = torch.clone(loss.detach())
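+        # also report the batch size alongside the losses in the per-step stats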
+ stats["batch_size"] = batch_size
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
@@ -451,11 +453,13 @@
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
- if isinstance(data_in, torch.Tensor): # fbank
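+        # only treat a raw tensor as precomputed fbank features when data_type says "fbank"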
+ if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
- if speech_lengths is None:
+ if speech_lengths is not None:
+ speech_lengths = speech_lengths.squeeze(-1)
+ else:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
@@ -491,6 +495,8 @@
b, n, d = decoder_out.size()
if isinstance(key[0], (list, tuple)):
key = key[0]
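+        # a single key may arrive for a batched input; replicate it to cover all b items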
+ if len(key) < b:
+            key = key * b
for i in range(b):
x = encoder_out[i, :encoder_out_lens[i], :]
am_scores = decoder_out[i, :pre_token_length[i], :]
@@ -512,9 +518,10 @@
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for nbest_idx, hyp in enumerate(nbest_hyps):
ibest_writer = None
- if ibest_writer is None and kwargs.get("output_dir") is not None:
- writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = writer[f"{nbest_idx+1}best_recog"]
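+                # build the DatadirWriter once and cache it on the model for reuse across calls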
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{nbest_idx+1}best_recog"]
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
@@ -528,13 +535,12 @@
if tokenizer is not None:
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
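+                # BPE tokenizers already yield final text from tokens2text; others go through sentence_postprocess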
+ text_postprocessed = tokenizer.tokens2text(token)
+ if not hasattr(tokenizer, "bpemodel"):
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
result_i = {"key": key[i], "text": text_postprocessed}
-
if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
# ibest_writer["text"][key[i]] = text
@@ -545,3 +551,10 @@
return results, meta_data
+ def export(self, **kwargs):
+ from .export_meta import export_rebuild_model
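+        # default cap on sequence length for the exported model unless the caller overrides it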
+        if "max_seq_len" not in kwargs:
+            kwargs["max_seq_len"] = 512
+ models = export_rebuild_model(model=self, **kwargs)
+ return models
+
--
Gitblit v1.9.1