From a8f0aad81de964941493c57351925071f3a8b733 Mon Sep 17 00:00:00 2001
From: sugarcase <shi.pengteng@163.com>
Date: 星期五, 27 九月 2024 14:16:28 +0800
Subject: [PATCH] fsmn_kws_mt finetune and inference adapt to right modelscope hub (#2113)
---
funasr/models/fsmn_kws_mt/model.py | 30 +++++++-----------------------
1 files changed, 7 insertions(+), 23 deletions(-)
diff --git a/funasr/models/fsmn_kws_mt/model.py b/funasr/models/fsmn_kws_mt/model.py
index c4645bb..3fa728a 100644
--- a/funasr/models/fsmn_kws_mt/model.py
+++ b/funasr/models/fsmn_kws_mt/model.py
@@ -41,8 +41,7 @@
encoder_conf: Optional[Dict] = None,
ctc_conf: Optional[Dict] = None,
input_size: int = 360,
- vocab_size: int = -1,
- vocab_size2: int = -1,
+ vocab_size: list = [],
ignore_id: int = -1,
blank_id: int = 0,
**kwargs,
@@ -63,14 +62,13 @@
encoder_output_size2 = encoder.output_size2()
ctc = CTC(
- odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ odim=vocab_size[0], encoder_output_size=encoder_output_size, **ctc_conf
)
ctc2 = CTC(
- odim=vocab_size2, encoder_output_size=encoder_output_size2, **ctc_conf
+ odim=vocab_size[1], encoder_output_size=encoder_output_size2, **ctc_conf
)
self.blank_id = blank_id
- self.vocab_size = vocab_size
self.ignore_id = ignore_id
# self.frontend = frontend
@@ -208,7 +206,6 @@
data_lengths=None,
key: list=None,
tokenizer=None,
- tokenizer2=None,
frontend=None,
**kwargs,
):
@@ -217,14 +214,14 @@
self.kws_decoder = KwsCtcPrefixDecoder(
ctc=self.ctc,
keywords=keywords,
- token_list=tokenizer.token_list,
- seg_dict=tokenizer.seg_dict,
+ token_list=tokenizer[0].token_list,
+ seg_dict=tokenizer[0].seg_dict,
)
self.kws_decoder2 = KwsCtcPrefixDecoder(
ctc=self.ctc2,
keywords=keywords,
- token_list=tokenizer2.token_list,
- seg_dict=tokenizer2.seg_dict,
+ token_list=tokenizer[1].token_list,
+ seg_dict=tokenizer[1].seg_dict,
)
meta_data = {}
@@ -314,12 +311,9 @@
self,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
- ctc: str = None,
ctc_conf: Optional[Dict] = None,
ctc_weight: float = 1.0,
input_size: int = 360,
- vocab_size: int = -1,
- vocab_size2: int = -1,
blank_id: int = 0,
**kwargs,
):
@@ -328,18 +322,8 @@
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(**encoder_conf)
encoder_output_size = encoder.output_size()
-
- if ctc_conf is None:
- ctc_conf = {}
- ctc = CTC(
- odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
- )
-
self.blank_id = blank_id
- self.vocab_size = vocab_size
- self.ctc_weight = ctc_weight
self.encoder = encoder
- self.ctc = ctc
self.error_calculator = None
--
Gitblit v1.9.1