From 2a66366be4c2715870e4859fd5a5db6e8a9dc00a Mon Sep 17 00:00:00 2001
From: chenmengzheAAA <123789350+chenmengzheAAA@users.noreply.github.com>
Date: 星期四, 14 九月 2023 19:00:17 +0800
Subject: [PATCH] Merge pull request #956 from alibaba-damo-academy/chenmengzheAAA-patch-4

---
 funasr/models/e2e_asr_paraformer.py |    8 +-------
 1 files changed, 1 insertions(+), 7 deletions(-)

diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 686038e..e157454 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -10,7 +10,6 @@
 import torch
 import random
 import numpy as np
-from typeguard import check_argument_types
 
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
@@ -80,7 +79,6 @@
             postencoder: Optional[AbsPostEncoder] = None,
             use_1st_decoder_loss: bool = False,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -645,7 +643,6 @@
             postencoder: Optional[AbsPostEncoder] = None,
             use_1st_decoder_loss: bool = False,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -1255,7 +1252,6 @@
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -1528,7 +1524,6 @@
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -1806,7 +1801,6 @@
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -2113,7 +2107,7 @@
 
         return loss_att, acc_att, cer_att, wer_att, loss_pre
 
-    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0):
         if hw_list is None:
             # default hotword list
             hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)]  # empty hotword list

--
Gitblit v1.9.1