From a8f0aad81de964941493c57351925071f3a8b733 Mon Sep 17 00:00:00 2001
From: sugarcase <shi.pengteng@163.com>
Date: Fri, 27 Sep 2024 14:16:28 +0800
Subject: [PATCH] fsmn_kws_mt finetune and inference adapt to right modelscope hub (#2113)

---
 examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh          |    6 +-
 examples/industrial_data_pretraining/fsmn_kws_mt/demo.py             |    2 
 examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh            |    2 
 examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh |   12 ++---
 funasr/auto/auto_model.py                                            |    5 +-
 funasr/models/fsmn_kws_mt/model.py                                   |   30 +++-----------
 examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh         |   12 ++---
 examples/industrial_data_pretraining/fsmn_kws_mt/convert.py          |    4 --
 funasr/download/download_model_from_hub.py                           |    5 +-
 9 files changed, 28 insertions(+), 50 deletions(-)

diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py
index e63e689..a6ef0f8 100644
--- a/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py
@@ -49,8 +49,6 @@
     copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))
 
     model = FsmnKWSMTConvert(
-        vocab_size=configs['encoder_conf']['output_dim'],
-        vocab_size2=configs['encoder_conf']['output_dim2'],
         encoder='FSMNMTConvert',
         encoder_conf=configs['encoder_conf'],
         ctc_conf=configs['ctc_conf'],
@@ -82,8 +80,6 @@
     model_name="convert.torch.pt"
 ):
     model = FsmnKWSMTConvert(
-        vocab_size=configs['encoder_conf']['output_dim'],
-        vocab_size2=configs['encoder_conf']['output_dim2'],
         encoder='FSMNMTConvert',
         encoder_conf=configs['encoder_conf'],
         ctc_conf=configs['ctc_conf'],
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh
index 30e2eed..26a47fc 100644
--- a/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh
@@ -5,16 +5,16 @@
 local_path_root=${workspace}/modelscope_models
 mkdir -p ${local_path_root}
 
-local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun_mt
 if [ ! -d "$local_path" ]; then
-    git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+    git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun_mt.git ${local_path}
 fi
 
 export PATH=${local_path}/runtime:$PATH
 export LD_LIBRARY_PATH=${local_path}/runtime:$LD_LIBRARY_PATH
 
 # finetune config file
-config=./conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml
+config=./conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
 
 # finetune output checkpoint
 torch_nnet=exp/finetune_outputs/model.pt.avg10
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/demo.py b/examples/industrial_data_pretraining/fsmn_kws_mt/demo.py
index 6bac47b..eef4d68 100644
--- a/examples/industrial_data_pretraining/fsmn_kws_mt/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/demo.py
@@ -6,7 +6,7 @@
 from funasr import AutoModel
 
 model = AutoModel(
-    model="iic/speech_charctc_kws_phone-xiaoyun",
+    model="iic/speech_charctc_kws_phone-xiaoyun_mt",
     keywords="小云小云",
     output_dir="./outputs/debug",
     device='cpu'
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh
index 1e87021..56162d7 100755
--- a/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh
@@ -27,19 +27,19 @@
 # model_name from model_hub, or model_dir in local path
 
 ## option 1, download model automatically, unsupported currently
-model_name_or_model_dir="iic/speech_charctc_kws_phone-xiaoyun"
+model_name_or_model_dir="iic/speech_charctc_kws_phone-xiaoyun_mt"
 
 ## option 2, download model by git
 local_path_root=${workspace}/modelscope_models
 model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
 if [ ! -d $model_name_or_model_dir ]; then
   mkdir -p ${model_name_or_model_dir}
-  git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${model_name_or_model_dir}
+  git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun_mt.git ${model_name_or_model_dir}
 fi
 
 config=fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
 token_list=${model_name_or_model_dir}/funasr/tokens_2599.txt
-token_list2=${model_name_or_model_dir}/funasr/tokens_xiaoyun_char.txt
+token_list2=${model_name_or_model_dir}/funasr/tokens_xiaoyun.txt
 lexicon_list=${model_name_or_model_dir}/funasr/lexicon.txt
 cmvn_file=${model_name_or_model_dir}/funasr/am.mvn.dim80_l2r2
 init_param="${model_name_or_model_dir}/funasr/basetrain_fsmn_4e_l10r2_250_128_fdim80_t2599.pt"
@@ -141,10 +141,8 @@
           --config-path="${output_dir}" \
           --config-name="config.yaml" \
           ++init_param="${output_dir}/${inference_checkpoint}" \
-          ++tokenizer_conf.token_list="${token_list}" \
-          ++tokenizer_conf.seg_dict="${lexicon_list}" \
-          ++tokenizer2_conf.token_list="${token_list2}" \
-          ++tokenizer2_conf.seg_dict="${lexicon_list}" \
+          ++token_lists='['''${token_list}''', '''${token_list2}''']' \
+          ++seg_dicts='['''${lexicon_list}''', '''${lexicon_list}''']' \
           ++frontend_conf.cmvn_file="${cmvn_file}" \
           ++keywords="\"$keywords_string"\" \
           ++input="${_logdir}/keys.${JOB}.scp" \
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh
index 6e03b89..3905b30 100644
--- a/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh
@@ -3,7 +3,7 @@
 
 # method1, inference from model hub
 
-model="iic/speech_charctc_kws_phone-xiaoyun"
+model="iic/speech_charctc_kws_phone-xiaoyun_mt"
 
 # for more input type, please ref to readme.md
 input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh
index 51d2312..f451d31 100644
--- a/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh
@@ -13,14 +13,14 @@
 # download model
 local_path_root=${workspace}/modelscope_models
 mkdir -p ${local_path_root}
-local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
-git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun_mt
+git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun_mt.git ${local_path}
 
 device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
 
 config="inference_fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml"
 tokens="${local_path}/funasr/tokens_2602.txt"
-tokens2="${local_path}/funasr/tokens_xiaoyun_char.txt"
+tokens2="${local_path}/funasr/tokens_xiaoyun.txt"
 seg_dict="${local_path}/funasr/lexicon.txt"
 init_param="${local_path}/funasr/finetune_fsmn_4e_l10r2_280_200_fdim40_t2602_t4_xiaoyun_xiaoyun.pt"
 cmvn_file="${local_path}/funasr/am.mvn.dim40_l4r4"
@@ -34,10 +34,8 @@
 --config-name "${config}" \
 ++init_param="${init_param}" \
 ++frontend_conf.cmvn_file="${cmvn_file}" \
-++tokenizer_conf.token_list="${tokens}" \
-++tokenizer_conf.seg_dict="${seg_dict}" \
-++tokenizer2_conf.token_list="${tokens2}" \
-++tokenizer2_conf.seg_dict="${seg_dict}" \
+++token_lists='['''${tokens}''', '''${tokens2}''']' \
+++seg_dicts='['''${seg_dict}''', '''${seg_dict}''']' \
 ++input="${input}" \
 ++output_dir="${output_dir}" \
 ++device="${device}" \
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index e08cb2b..71f44b4 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -199,6 +199,7 @@
             tokenizers_build = []
             vocab_sizes = []
             token_lists = []
+
             ### === only for kws ===
             token_list_files = kwargs.get("token_lists", [])
             seg_dicts = kwargs.get("seg_dicts", [])
@@ -213,9 +214,9 @@
 
                 ### === only for kws ===
                 if len(token_list_files) > 1:
-                    tokenizer_conf.token_list = token_list_files[i]
+                    tokenizer_conf["token_list"] = token_list_files[i]
                 if len(seg_dicts) > 1:
-                    tokenizer_conf.seg_dict = seg_dicts[i]
+                    tokenizer_conf["seg_dict"] = seg_dicts[i]
                 ### === only for kws ===
 
                 tokenizer = tokenizer_class(**tokenizer_conf)
diff --git a/funasr/download/download_model_from_hub.py b/funasr/download/download_model_from_hub.py
index 8e51144..f7eea2a 100644
--- a/funasr/download/download_model_from_hub.py
+++ b/funasr/download/download_model_from_hub.py
@@ -162,6 +162,7 @@
     if isinstance(file_path_metas, dict):
         if isinstance(cfg, list):
             cfg.append({})
+
         for k, v in file_path_metas.items():
             if isinstance(v, str):
                 p = os.path.join(model_or_path, v)
@@ -186,8 +187,8 @@
                     if k not in cfg:
                         cfg[k] = []
                     if isinstance(vv, str):
-                        p = os.path.join(model_or_path, v)
-                        file_path_metas[i] = p
+                        p = os.path.join(model_or_path, vv)
+                        # file_path_metas[i] = p
                         if os.path.exists(p):
                             if isinstance(cfg[k], dict):
                                 cfg[k] = p
diff --git a/funasr/models/fsmn_kws_mt/model.py b/funasr/models/fsmn_kws_mt/model.py
index c4645bb..3fa728a 100644
--- a/funasr/models/fsmn_kws_mt/model.py
+++ b/funasr/models/fsmn_kws_mt/model.py
@@ -41,8 +41,7 @@
         encoder_conf: Optional[Dict] = None,
         ctc_conf: Optional[Dict] = None,
         input_size: int = 360,
-        vocab_size: int = -1,
-        vocab_size2: int = -1,
+        vocab_size: list = [],
         ignore_id: int = -1,
         blank_id: int = 0,
         **kwargs,
@@ -63,14 +62,13 @@
         encoder_output_size2 = encoder.output_size2()
 
         ctc = CTC(
-            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+            odim=vocab_size[0], encoder_output_size=encoder_output_size, **ctc_conf
         )
         ctc2 = CTC(
-            odim=vocab_size2, encoder_output_size=encoder_output_size2, **ctc_conf
+            odim=vocab_size[1], encoder_output_size=encoder_output_size2, **ctc_conf
         )
 
         self.blank_id = blank_id
-        self.vocab_size = vocab_size
         self.ignore_id = ignore_id
 
         # self.frontend = frontend
@@ -208,7 +206,6 @@
         data_lengths=None,
         key: list=None,
         tokenizer=None,
-        tokenizer2=None,
         frontend=None,
         **kwargs,
     ):
@@ -217,14 +214,14 @@
         self.kws_decoder = KwsCtcPrefixDecoder(
             ctc=self.ctc,
             keywords=keywords,
-            token_list=tokenizer.token_list,
-            seg_dict=tokenizer.seg_dict,
+            token_list=tokenizer[0].token_list,
+            seg_dict=tokenizer[0].seg_dict,
         )
         self.kws_decoder2 = KwsCtcPrefixDecoder(
             ctc=self.ctc2,
             keywords=keywords,
-            token_list=tokenizer2.token_list,
-            seg_dict=tokenizer2.seg_dict,
+            token_list=tokenizer[1].token_list,
+            seg_dict=tokenizer[1].seg_dict,
         )
 
         meta_data = {}
@@ -314,12 +311,9 @@
         self,
         encoder: str = None,
         encoder_conf: Optional[Dict] = None,
-        ctc: str = None,
         ctc_conf: Optional[Dict] = None,
         ctc_weight: float = 1.0,
         input_size: int = 360,
-        vocab_size: int = -1,
-        vocab_size2: int = -1,
         blank_id: int = 0,
         **kwargs,
     ):
@@ -328,18 +322,8 @@
         encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
         encoder_output_size = encoder.output_size()
-
-        if ctc_conf is None:
-            ctc_conf = {}
-        ctc = CTC(
-            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
-        )
-
         self.blank_id = blank_id
-        self.vocab_size = vocab_size
-        self.ctc_weight = ctc_weight
         self.encoder = encoder
-        self.ctc = ctc
 
         self.error_calculator = None
 

--
Gitblit v1.9.1