From 806a03609df033d61f824f1ab8527eb88fe837ad Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Dec 2023 19:43:13 +0800
Subject: [PATCH] funasr2 paraformer bicif paraformer contextual paraformer

---
 funasr/cli/models/paraformer.py |    7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/funasr/cli/models/paraformer.py b/funasr/cli/models/paraformer.py
index ced1c23..7ca80f5 100644
--- a/funasr/cli/models/paraformer.py
+++ b/funasr/cli/models/paraformer.py
@@ -193,6 +193,7 @@
 			self.decoder.embed = None
 
 		self.use_1st_decoder_loss = use_1st_decoder_loss
+		self.length_normalized_loss = length_normalized_loss
 	
 	def forward(
 		self,
@@ -302,6 +303,8 @@
 		stats["loss"] = torch.clone(loss.detach())
 		
 		# force_gatherable: to-device and to-tensor if scalar for DataParallel
+		if self.length_normalized_loss:
+			batch_size = (text_lengths + self.predictor_bias).sum()
 		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
 		return loss, stats, weight
 	
@@ -594,7 +597,7 @@
 			for li in range(bsz):
 				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
 				if target_num > 0:
-					input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+					input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
 			input_mask = input_mask.eq(1)
 			input_mask = input_mask.masked_fill(~nonpad_positions, False)
 			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
@@ -624,7 +627,7 @@
 		for li in range(bsz):
 			target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
 			if target_num > 0:
-				input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num], value=0)
+				input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
 		input_mask = input_mask.eq(1)
 		input_mask = input_mask.masked_fill(~nonpad_positions, False)
 		input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

--
Gitblit v1.9.1