From f62d6314b93416b301ca7bb0b62ccf9c4e6933e2 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 11 七月 2023 00:52:21 +0800
Subject: [PATCH] update

---
 funasr/models/e2e_asr_paraformer.py |   12 +++---------
 1 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 09af2cd..5a1a29b 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -10,7 +10,6 @@
 import torch
 import random
 import numpy as np
-from typeguard import check_argument_types
 
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
@@ -80,7 +79,6 @@
             postencoder: Optional[AbsPostEncoder] = None,
             use_1st_decoder_loss: bool = False,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -242,7 +240,7 @@
             loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
 
         if self.use_1st_decoder_loss and pre_loss_att is not None:
-            loss = loss + pre_loss_att
+            loss = loss + (1 - self.ctc_weight) * pre_loss_att
 
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -645,7 +643,6 @@
             postencoder: Optional[AbsPostEncoder] = None,
             use_1st_decoder_loss: bool = False,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -1160,8 +1157,8 @@
                                                                                            mask_chunk_predictor=mask_chunk_predictor,
                                                                                            target_label_length=None,
                                                                                            )
-        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas[:, :-1],
-                                                                                             encoder_out_lens)
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+                                                                                             encoder_out_lens+1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
 
         scama_mask = None
         if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
@@ -1255,7 +1252,6 @@
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -1528,7 +1524,6 @@
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -1806,7 +1801,6 @@
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 

--
Gitblit v1.9.1