From d77910eb6d171727f2350e45c31c91436c4c8891 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 十二月 2023 13:42:40 +0800
Subject: [PATCH] funasr2

---
 funasr/cli/models/paraformer.py |    3 +++
 1 files changed, 3 insertions(+), 0 deletions(-)

diff --git a/funasr/cli/models/paraformer.py b/funasr/cli/models/paraformer.py
index ee8c0b4..7ca80f5 100644
--- a/funasr/cli/models/paraformer.py
+++ b/funasr/cli/models/paraformer.py
@@ -193,6 +193,7 @@
 			self.decoder.embed = None
 
 		self.use_1st_decoder_loss = use_1st_decoder_loss
+		self.length_normalized_loss = length_normalized_loss
 	
 	def forward(
 		self,
@@ -302,6 +303,8 @@
 		stats["loss"] = torch.clone(loss.detach())
 		
 		# force_gatherable: to-device and to-tensor if scalar for DataParallel
+		if self.length_normalized_loss:
+			batch_size = (text_lengths + self.predictor_bias).sum()
 		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
 		return loss, stats, weight
 	

--
Gitblit v1.9.1