From 3c3754dcc7568e76fa7d4b2c4e14849f68cc6ee7 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期日, 28 五月 2023 23:46:01 +0800
Subject: [PATCH] update repo
---
funasr/models/e2e_asr_paraformer.py | 3 ++-
1 files changed, 2 insertions(+), 1 deletions(-)
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 060442c..82acef2 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -236,10 +236,11 @@
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
if self.use_1st_decoder_loss and pre_loss_att is not None:
- loss = loss + pre_loss_att
+ loss = loss + (1 - self.ctc_weight) * pre_loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
--
Gitblit v1.9.1