From 62178770dccdbf5da42e831898ea32adeeacba45 Mon Sep 17 00:00:00 2001
From: 语帆 <yf352572@alibaba-inc.com>
Date: 星期三, 21 二月 2024 20:04:01 +0800
Subject: [PATCH] test

---
 funasr/models/contextual_paraformer/model.py |   29 +++++++++++------------------
 1 files changed, 11 insertions(+), 18 deletions(-)

diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 10bbf9d..1c0805a 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -102,17 +102,16 @@
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
             speech_lengths = speech_lengths[:, 0]
-        pdb.set_trace()
+
         batch_size = speech.shape[0]
 
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         dha_pad = kwargs.get("dha_pad")
-        pdb.set_trace()
+
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-        pdb.set_trace()
         loss_ctc, cer_ctc = None, None
         
         stats = dict()
@@ -127,12 +126,11 @@
             stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
             stats["cer_ctc"] = cer_ctc
         
-        pdb.set_trace()
         # 2b. Attention decoder branch
         loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
             encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
         )
-        pdb.set_trace()
+
         # 3. CTC-Att loss definition
         if self.ctc_weight == 0.0:
             loss = loss_att + loss_pre * self.predictor_weight
@@ -170,26 +168,24 @@
     ):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pdb.set_trace()
+
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
-        pdb.set_trace()
+
         pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                      ignore_id=self.ignore_id)
-        pdb.set_trace()
         # -1. bias encoder
         if self.use_decoder_embedding:
             hw_embed = self.decoder.embed(hotword_pad)
         else:
             hw_embed = self.bias_embed(hotword_pad)
-        pdb.set_trace()
+
         hw_embed, (_, _) = self.bias_encoder(hw_embed)
-        pdb.set_trace()
         _ind = np.arange(0, hotword_pad.shape[0]).tolist()
         selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
         contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
-        pdb.set_trace()
+
         # 0. sampler
         decoder_out_1st = None
         if self.sampling_ratio > 0.0:
@@ -201,7 +197,7 @@
             if self.step_cur < 2:
                 logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
             sematic_embeds = pre_acoustic_embeds
-        pdb.set_trace()
+
         # 1. Forward decoder
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -217,7 +213,7 @@
             loss_ideal = None
         '''
         loss_ideal = None
-        pdb.set_trace()
+
         if decoder_out_1st is None:
             decoder_out_1st = decoder_out
         # 2. Compute attention loss
@@ -294,11 +290,11 @@
                                                                enforce_sorted=False)
             _, (h_n, _) = self.bias_encoder(hw_embed)
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-        pdb.set_trace()
+
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
         )
-        pdb.set_trace()
+
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
@@ -363,14 +359,11 @@
                                                                  clas_scale=kwargs.get("clas_scale", 1.0))
         decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
         
-        pdb.set_trace()
         results = []
         b, n, d = decoder_out.size()
-        pdb.set_trace()
         for i in range(b):
             x = encoder_out[i, :encoder_out_lens[i], :]
             am_scores = decoder_out[i, :pre_token_length[i], :]
-            pdb.set_trace()
             if self.beam_search is not None:
                 nbest_hyps = self.beam_search(
                     x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),

--
Gitblit v1.9.1