From 5608bee0accea5e12030f8e1b6f7d62eee4dd892 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 01 Mar 2024 18:36:40 +0800
Subject: [PATCH] fixbug (#1412)
---
funasr/models/paraformer/model.py | 4 ++--
funasr/models/seaco_paraformer/model.py | 1 -
funasr/models/contextual_paraformer/model.py | 5 +----
funasr/models/paraformer_streaming/model.py | 5 +----
4 files changed, 4 insertions(+), 11 deletions(-)
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 939d46d..49868a8 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -190,13 +190,10 @@
# 0. sampler
decoder_out_1st = None
if self.sampling_ratio > 0.0:
- if self.step_cur < 2:
- logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+
sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
pre_acoustic_embeds, contextual_info)
else:
- if self.step_cur < 2:
- logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
sematic_embeds = pre_acoustic_embeds
# 1. Forward decoder
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 90ce162..51c9bdb 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -154,8 +154,8 @@
self.predictor_bias = predictor_bias
self.sampling_ratio = sampling_ratio
self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
- # self.step_cur = 0
- #
+
+
self.share_embedding = share_embedding
if self.share_embedding:
self.decoder.embed = None
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 45875a2..4cf20de 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -235,8 +235,7 @@
decoder_out_1st = None
pre_loss_att = None
if self.sampling_ratio > 0.0:
- if self.step_cur < 2:
- logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+
if self.use_1st_decoder_loss:
sematic_embeds, decoder_out_1st, pre_loss_att = \
self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
@@ -246,8 +245,6 @@
self.sampler(encoder_out, encoder_out_lens, ys_pad,
ys_pad_lens, pre_acoustic_embeds, scama_mask)
else:
- if self.step_cur < 2:
- logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
sematic_embeds = pre_acoustic_embeds
# 1. Forward decoder
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 21ad874..20b0cc8 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -130,7 +130,6 @@
dha_pad = kwargs.get("dha_pad")
batch_size = speech.shape[0]
- self.step_cur += 1
# for data-parallel
text = text[:, : text_lengths.max()]
speech = speech[:, :speech_lengths.max()]
--
Gitblit v1.9.1