From 5231b54af843a486baf4649cfc45b4f06c9914c8 Mon Sep 17 00:00:00 2001
From: 北念 <lzr265946@alibaba-inc.com>
Date: 星期四, 09 二月 2023 19:37:21 +0800
Subject: [PATCH] add ContextualParaformer

---
 funasr/tasks/asr.py |    8 +++++++-
 1 files changed, 7 insertions(+), 1 deletions(-)

diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 1b7f152..e62a748 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -37,8 +37,9 @@
 )
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
 from funasr.models.e2e_asr import ESPnetASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_uni_asr import UniASR
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
@@ -117,6 +118,7 @@
         paraformer=Paraformer,
         paraformer_bert=ParaformerBert,
         bicif_paraformer=BiCifParaformer,
+        contextual_paraformer=ContextualParaformer,
     ),
     type_check=AbsESPnetModel,
     default="asr",
@@ -177,6 +179,7 @@
         fsmn_scama_opt=FsmnDecoderSCAMAOpt,
         paraformer_decoder_sanm=ParaformerSANMDecoder,
         paraformer_decoder_san=ParaformerDecoderSAN,
+        contextual_paraformer_decoder=ContextualParaformerDecoder,
     ),
     type_check=AbsDecoder,
     default="rnn",
@@ -1098,5 +1101,8 @@
         # decoder
         var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
+        # bias_encoder
+        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update

--
Gitblit v1.9.1