From 62178770dccdbf5da42e831898ea32adeeacba45 Mon Sep 17 00:00:00 2001
From: 语帆 <yf352572@alibaba-inc.com>
Date: Wed, 21 Feb 2024 20:04:01 +0800
Subject: [PATCH] Remove leftover pdb.set_trace() debug calls from contextual_paraformer model
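Drop the unconditional pdb.set_trace() calls left over from debugging in
funasr/models/contextual_paraformer/model.py. Each call drops every forward
pass into the interactive debugger, blocking training and inference runs.

If interactive debugging is still wanted here, a minimal opt-in sketch could
gate the breakpoint behind an environment variable instead. The helper name
maybe_breakpoint and the FUNASR_DEBUG variable below are assumptions for
illustration, not part of this patch or of FunASR:

    import os
    import pdb

    def maybe_breakpoint():
        # Hypothetical helper: enter pdb only when FUNASR_DEBUG=1 is set,
        # so normal training/inference runs are never interrupted.
        if os.environ.get("FUNASR_DEBUG") == "1":
            pdb.set_trace()

Calling maybe_breakpoint() at a suspect line behaves like pdb.set_trace()
when FUNASR_DEBUG=1 is exported and is a no-op otherwise.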
---
funasr/models/contextual_paraformer/model.py | 29 +++++++++++------------------
1 file changed, 11 insertions(+), 18 deletions(-)
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 10bbf9d..1c0805a 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -102,17 +102,16 @@
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
- pdb.set_trace()
+
batch_size = speech.shape[0]
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
dha_pad = kwargs.get("dha_pad")
- pdb.set_trace()
+
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- pdb.set_trace()
loss_ctc, cer_ctc = None, None
stats = dict()
@@ -127,12 +126,11 @@
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
- pdb.set_trace()
# 2b. Attention decoder branch
loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
)
- pdb.set_trace()
+
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att + loss_pre * self.predictor_weight
@@ -170,26 +168,24 @@
):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pdb.set_trace()
+
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
- pdb.set_trace()
+
pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
ignore_id=self.ignore_id)
- pdb.set_trace()
# -1. bias encoder
if self.use_decoder_embedding:
hw_embed = self.decoder.embed(hotword_pad)
else:
hw_embed = self.bias_embed(hotword_pad)
- pdb.set_trace()
+
hw_embed, (_, _) = self.bias_encoder(hw_embed)
- pdb.set_trace()
_ind = np.arange(0, hotword_pad.shape[0]).tolist()
selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
- pdb.set_trace()
+
# 0. sampler
decoder_out_1st = None
if self.sampling_ratio > 0.0:
@@ -201,7 +197,7 @@
if self.step_cur < 2:
logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
sematic_embeds = pre_acoustic_embeds
- pdb.set_trace()
+
# 1. Forward decoder
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -217,7 +213,7 @@
loss_ideal = None
'''
loss_ideal = None
- pdb.set_trace()
+
if decoder_out_1st is None:
decoder_out_1st = decoder_out
# 2. Compute attention loss
@@ -294,11 +290,11 @@
enforce_sorted=False)
_, (h_n, _) = self.bias_encoder(hw_embed)
hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
- pdb.set_trace()
+
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
)
- pdb.set_trace()
+
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
@@ -363,14 +359,11 @@
clas_scale=kwargs.get("clas_scale", 1.0))
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
- pdb.set_trace()
results = []
b, n, d = decoder_out.size()
- pdb.set_trace()
for i in range(b):
x = encoder_out[i, :encoder_out_lens[i], :]
am_scores = decoder_out[i, :pre_token_length[i], :]
- pdb.set_trace()
if self.beam_search is not None:
nbest_hyps = self.beam_search(
x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
--
Gitblit v1.9.1