From 5231b54af843a486baf4649cfc45b4f06c9914c8 Mon Sep 17 00:00:00 2001
From: 北念 <lzr265946@alibaba-inc.com>
Date: 星期四, 09 二月 2023 19:37:21 +0800
Subject: [PATCH] add ContextualParaformer
---
funasr/tasks/asr.py | 8 +++++++-
1 files changed, 7 insertions(+), 1 deletions(-)
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 1b7f152..e62a748 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -37,8 +37,9 @@
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
@@ -117,6 +118,7 @@
paraformer=Paraformer,
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
+ contextual_paraformer=ContextualParaformer,
),
type_check=AbsESPnetModel,
default="asr",
@@ -177,6 +179,7 @@
fsmn_scama_opt=FsmnDecoderSCAMAOpt,
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
+ contextual_paraformer_decoder=ContextualParaformerDecoder,
),
type_check=AbsDecoder,
default="rnn",
@@ -1098,5 +1101,8 @@
# decoder
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
+ # bias_encoder
+ var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
--
Gitblit v1.9.1