From bc723ea200144bd6fa8a5dff4b9a780feda144fc Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 29 六月 2023 18:55:01 +0800
Subject: [PATCH] dcos
---
funasr/models/e2e_asr_paraformer.py | 8 +-------
1 files changed, 1 insertions(+), 7 deletions(-)
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index fcd6503..5a1a29b 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -10,7 +10,6 @@
import torch
import random
import numpy as np
-from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@@ -80,7 +79,6 @@
postencoder: Optional[AbsPostEncoder] = None,
use_1st_decoder_loss: bool = False,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -242,7 +240,7 @@
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
if self.use_1st_decoder_loss and pre_loss_att is not None:
- loss = loss + pre_loss_att
+ loss = loss + (1 - self.ctc_weight) * pre_loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -645,7 +643,6 @@
postencoder: Optional[AbsPostEncoder] = None,
use_1st_decoder_loss: bool = False,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -1255,7 +1252,6 @@
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -1528,7 +1524,6 @@
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -1806,7 +1801,6 @@
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
--
Gitblit v1.9.1