From 5608bee0accea5e12030f8e1b6f7d62eee4dd892 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 01 三月 2024 18:36:40 +0800
Subject: [PATCH] fixbug (#1412)

---
 funasr/models/paraformer/model.py            |    4 ++--
 funasr/models/seaco_paraformer/model.py      |    1 -
 funasr/models/contextual_paraformer/model.py |    5 +----
 funasr/models/paraformer_streaming/model.py  |    5 +----
 4 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 939d46d..49868a8 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -190,13 +190,10 @@
         # 0. sampler
         decoder_out_1st = None
         if self.sampling_ratio > 0.0:
-            if self.step_cur < 2:
-                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+
             sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                            pre_acoustic_embeds, contextual_info)
         else:
-            if self.step_cur < 2:
-                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
             sematic_embeds = pre_acoustic_embeds
         
         # 1. Forward decoder
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 90ce162..51c9bdb 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -154,8 +154,8 @@
         self.predictor_bias = predictor_bias
         self.sampling_ratio = sampling_ratio
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
-        # self.step_cur = 0
-        #
+
+
         self.share_embedding = share_embedding
         if self.share_embedding:
             self.decoder.embed = None
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 45875a2..4cf20de 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -235,8 +235,7 @@
         decoder_out_1st = None
         pre_loss_att = None
         if self.sampling_ratio > 0.0:
-            if self.step_cur < 2:
-                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+
             if self.use_1st_decoder_loss:
                 sematic_embeds, decoder_out_1st, pre_loss_att = \
                     self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
@@ -246,8 +245,6 @@
                     self.sampler(encoder_out, encoder_out_lens, ys_pad,
                                  ys_pad_lens, pre_acoustic_embeds, scama_mask)
         else:
-            if self.step_cur < 2:
-                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
             sematic_embeds = pre_acoustic_embeds
         
         # 1. Forward decoder
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 21ad874..20b0cc8 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -130,7 +130,6 @@
         dha_pad = kwargs.get("dha_pad")
 
         batch_size = speech.shape[0]
-        self.step_cur += 1
         # for data-parallel
         text = text[:, : text_lengths.max()]
         speech = speech[:, :speech_lengths.max()]

--
Gitblit v1.9.1