From a8f0aad81de964941493c57351925071f3a8b733 Mon Sep 17 00:00:00 2001
From: sugarcase <shi.pengteng@163.com>
Date: Fri, 27 Sep 2024 14:16:28 +0800
Subject: [PATCH] fsmn_kws_mt: adapt finetuning and inference to the correct ModelScope hub (#2113)

---
 funasr/models/fsmn_kws_mt/model.py |   30 +++++++-----------------------
 1 file changed, 7 insertions(+), 23 deletions(-)

diff --git a/funasr/models/fsmn_kws_mt/model.py b/funasr/models/fsmn_kws_mt/model.py
index c4645bb..3fa728a 100644
--- a/funasr/models/fsmn_kws_mt/model.py
+++ b/funasr/models/fsmn_kws_mt/model.py
@@ -41,8 +41,7 @@
         encoder_conf: Optional[Dict] = None,
         ctc_conf: Optional[Dict] = None,
         input_size: int = 360,
-        vocab_size: int = -1,
-        vocab_size2: int = -1,
+        vocab_size: list = [],
         ignore_id: int = -1,
         blank_id: int = 0,
         **kwargs,
@@ -63,14 +62,13 @@
         encoder_output_size2 = encoder.output_size2()
 
         ctc = CTC(
-            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+            odim=vocab_size[0], encoder_output_size=encoder_output_size, **ctc_conf
         )
         ctc2 = CTC(
-            odim=vocab_size2, encoder_output_size=encoder_output_size2, **ctc_conf
+            odim=vocab_size[1], encoder_output_size=encoder_output_size2, **ctc_conf
         )
 
         self.blank_id = blank_id
-        self.vocab_size = vocab_size
         self.ignore_id = ignore_id
 
         # self.frontend = frontend
@@ -208,7 +206,6 @@
         data_lengths=None,
         key: list=None,
         tokenizer=None,
-        tokenizer2=None,
         frontend=None,
         **kwargs,
     ):
@@ -217,14 +214,14 @@
         self.kws_decoder = KwsCtcPrefixDecoder(
             ctc=self.ctc,
             keywords=keywords,
-            token_list=tokenizer.token_list,
-            seg_dict=tokenizer.seg_dict,
+            token_list=tokenizer[0].token_list,
+            seg_dict=tokenizer[0].seg_dict,
         )
         self.kws_decoder2 = KwsCtcPrefixDecoder(
             ctc=self.ctc2,
             keywords=keywords,
-            token_list=tokenizer2.token_list,
-            seg_dict=tokenizer2.seg_dict,
+            token_list=tokenizer[1].token_list,
+            seg_dict=tokenizer[1].seg_dict,
         )
 
         meta_data = {}
@@ -314,12 +311,9 @@
         self,
         encoder: str = None,
         encoder_conf: Optional[Dict] = None,
-        ctc: str = None,
         ctc_conf: Optional[Dict] = None,
         ctc_weight: float = 1.0,
         input_size: int = 360,
-        vocab_size: int = -1,
-        vocab_size2: int = -1,
         blank_id: int = 0,
         **kwargs,
     ):
@@ -328,18 +322,8 @@
         encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
         encoder_output_size = encoder.output_size()
-
-        if ctc_conf is None:
-            ctc_conf = {}
-        ctc = CTC(
-            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
-        )
-
         self.blank_id = blank_id
-        self.vocab_size = vocab_size
-        self.ctc_weight = ctc_weight
         self.encoder = encoder
-        self.ctc = ctc
 
         self.error_calculator = None
 

--
Gitblit v1.9.1