From 806a03609df033d61f824f1ab8527eb88fe837ad Mon Sep 17 00:00:00 2001 From: 游雁 <zhifu.gzf@alibaba-inc.com> Date: 星期三, 13 十二月 2023 19:43:13 +0800 Subject: [PATCH] funasr2 paraformer biciparaformer contextuaparaformer --- funasr/cli/models/paraformer.py | 7 +++++-- 1 files changed, 5 insertions(+), 2 deletions(-) diff --git a/funasr/cli/models/paraformer.py b/funasr/cli/models/paraformer.py index ced1c23..7ca80f5 100644 --- a/funasr/cli/models/paraformer.py +++ b/funasr/cli/models/paraformer.py @@ -193,6 +193,7 @@ self.decoder.embed = None self.use_1st_decoder_loss = use_1st_decoder_loss + self.length_normalized_loss = length_normalized_loss def forward( self, @@ -302,6 +303,8 @@ stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum() loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight @@ -594,7 +597,7 @@ for li in range(bsz): target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long() if target_num > 0: - input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0) + input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0) input_mask = input_mask.eq(1) input_mask = input_mask.masked_fill(~nonpad_positions, False) input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device) @@ -624,7 +627,7 @@ for li in range(bsz): target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long() if target_num > 0: - input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num], value=0) + input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0) input_mask = input_mask.eq(1) input_mask = input_mask.masked_fill(~nonpad_positions, False) input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device) -- Gitblit v1.9.1