From d77910eb6d171727f2350e45c31c91436c4c8891 Mon Sep 17 00:00:00 2001 From: 游雁 <zhifu.gzf@alibaba-inc.com> Date: 星期一, 11 十二月 2023 13:42:40 +0800 Subject: [PATCH] funasr2 --- funasr/cli/models/paraformer.py | 3 +++ 1 files changed, 3 insertions(+), 0 deletions(-) diff --git a/funasr/cli/models/paraformer.py b/funasr/cli/models/paraformer.py index ee8c0b4..7ca80f5 100644 --- a/funasr/cli/models/paraformer.py +++ b/funasr/cli/models/paraformer.py @@ -193,6 +193,7 @@ self.decoder.embed = None self.use_1st_decoder_loss = use_1st_decoder_loss + self.length_normalized_loss = length_normalized_loss def forward( self, @@ -302,6 +303,8 @@ stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum() loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight -- Gitblit v1.9.1