From e98e10639d90c55a4b7e498d0d87837ad9c4173d Mon Sep 17 00:00:00 2001 From: 游雁 <zhifu.gzf@alibaba-inc.com> Date: 星期三, 06 十二月 2023 19:42:02 +0800 Subject: [PATCH] funasr2 --- funasr/cli/models/paraformer.py | 4 ++-- 1 files changed, 2 insertions(+), 2 deletions(-) diff --git a/funasr/cli/models/paraformer.py b/funasr/cli/models/paraformer.py index ced1c23..ee8c0b4 100644 --- a/funasr/cli/models/paraformer.py +++ b/funasr/cli/models/paraformer.py @@ -594,7 +594,7 @@ for li in range(bsz): target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long() if target_num > 0: - input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0) + input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0) input_mask = input_mask.eq(1) input_mask = input_mask.masked_fill(~nonpad_positions, False) input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device) @@ -624,7 +624,7 @@ for li in range(bsz): target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long() if target_num > 0: - input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num], value=0) + input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0) input_mask = input_mask.eq(1) input_mask = input_mask.masked_fill(~nonpad_positions, False) input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device) -- Gitblit v1.9.1