From 835369d6315e96c1820326ed11ea4b999793720f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Sat, 13 Jan 2024 22:42:18 +0800
Subject: [PATCH] funasr1.0 fix punc model

---
 funasr/models/ct_transformer/utils.py                           |  115 +++++++++++++++++++++++++++++++------
 examples/industrial_data_pretraining/ct_transformer/demo.py     |   10 +++
 examples/industrial_data_pretraining/paraformer-zh-spk/demo.py  |    2 
 examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh |    2 
 funasr/download/download_from_hub.py                            |    2 
 funasr/models/ct_transformer/model.py                           |   10 ++
 examples/industrial_data_pretraining/ct_transformer/infer.sh    |    5 +
 examples/industrial_data_pretraining/bicif_paraformer/demo.py   |    4 
 examples/industrial_data_pretraining/bicif_paraformer/infer.sh  |    2 
 examples/industrial_data_pretraining/seaco_paraformer/infer.sh  |    2 
 examples/industrial_data_pretraining/seaco_paraformer/demo.py   |    2 
 11 files changed, 125 insertions(+), 31 deletions(-)

diff --git a/examples/industrial_data_pretraining/bicif_paraformer/demo.py b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
index 84b0e80..4d921ea 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -10,7 +10,7 @@
                     vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                     vad_model_revision="v2.0.1",
                     punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                    punc_model_revision="v2.0.0",
+                    punc_model_revision="v2.0.1",
                     spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                   )
 
@@ -23,7 +23,7 @@
                     vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                     vad_model_revision="v2.0.1",
                     punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                    punc_model_revision="v2.0.0",
+                    punc_model_revision="v2.0.1",
                     spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                     spk_mode='punc_segment',
                   )
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/infer.sh b/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
index 04cb6f2..57c5838 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
@@ -4,7 +4,7 @@
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.0"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 
 python funasr/bin/inference.py \
 +model=${model} \
diff --git a/examples/industrial_data_pretraining/ct_transformer/demo.py b/examples/industrial_data_pretraining/ct_transformer/demo.py
index 58ebd2a..23965e0 100644
--- a/examples/industrial_data_pretraining/ct_transformer/demo.py
+++ b/examples/industrial_data_pretraining/ct_transformer/demo.py
@@ -5,7 +5,15 @@
 
 from funasr import AutoModel
 
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.0")
+model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.1")
+
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+print(res)
+
+
+from funasr import AutoModel
+
+model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.1")
 
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/infer.sh b/examples/industrial_data_pretraining/ct_transformer/infer.sh
index a48d562..4b4e949 100644
--- a/examples/industrial_data_pretraining/ct_transformer/infer.sh
+++ b/examples/industrial_data_pretraining/ct_transformer/infer.sh
@@ -1,6 +1,9 @@
 
 model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-model_revision="v2.0.0"
+model_revision="v2.0.1"
+
+model="damo/punc_ct-transformer_cn-en-common-vocab471067-large"
+model_revision="v2.0.1"
 
 python funasr/bin/inference.py \
 +model=${model} \
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
index 774d757..fcf5f60 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -10,7 +10,7 @@
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                   spk_model_revision="v2.0.0"
                   )
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
index a457401..63347b6 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
@@ -4,7 +4,7 @@
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.1"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
 spk_model_revision="v2.0.0"
 
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 63f155e..3b5963a 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -10,7 +10,7 @@
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   )
 
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
index 26eeee1..c46449f 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
@@ -4,7 +4,7 @@
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.1"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 
 python funasr/bin/inference.py \
 +model=${model} \
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 946572f..27bd79d 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -37,6 +37,8 @@
 		kwargs["model"] = cfg["model"]
 		if os.path.exists(os.path.join(model_or_path, "am.mvn")):
 			kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+		if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+			kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
 	else:# configuration.json
 		assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
 		with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index fbf1804..d843686 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -225,8 +225,14 @@
         # text = data_in[0]
         # text_lengths = data_lengths[0] if data_lengths is not None else None
         split_size = kwargs.get("split_size", 20)
-        
-        tokens = split_words(text)
+
+        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
+        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
+            import jieba
+            jieba.load_userdict(jieba_usr_dict)
+            jieba_usr_dict = jieba
+            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
+        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
         tokens_int = tokenizer.encode(tokens)
 
         mini_sentences = split_to_mini_sentence(tokens, split_size)
diff --git a/funasr/models/ct_transformer/utils.py b/funasr/models/ct_transformer/utils.py
index a4a00e0..917f2e0 100644
--- a/funasr/models/ct_transformer/utils.py
+++ b/funasr/models/ct_transformer/utils.py
@@ -1,4 +1,4 @@
-
+import re
 
 def split_to_mini_sentence(words: list, word_limit: int = 20):
     assert word_limit > 1
@@ -14,23 +14,98 @@
     return sentences
 
 
-def split_words(text: str):
-    words = []
-    segs = text.split()
-    for seg in segs:
-        # There is no space in seg.
-        current_word = ""
-        for c in seg:
-            if len(c.encode()) == 1:
-                # This is an ASCII char.
-                current_word += c
+# def split_words(text: str, **kwargs):
+#     words = []
+#     segs = text.split()
+#     for seg in segs:
+#         # There is no space in seg.
+#         current_word = ""
+#         for c in seg:
+#             if len(c.encode()) == 1:
+#                 # This is an ASCII char.
+#                 current_word += c
+#             else:
+#                 # This is a Chinese char.
+#                 if len(current_word) > 0:
+#                     words.append(current_word)
+#                     current_word = ""
+#                 words.append(c)
+#         if len(current_word) > 0:
+#             words.append(current_word)
+#
+#     return words
+
+def split_words(text: str, jieba_usr_dict=None, **kwargs):
+    if jieba_usr_dict:
+        input_list = text.split()
+        token_list_all = []
+        language_list = []
+        token_list_tmp = []
+        language_flag = None
+        for token in input_list:
+            if isEnglish(token) and language_flag == 'Chinese':
+                token_list_all.append(token_list_tmp)
+                language_list.append('Chinese')
+                token_list_tmp = []
+            elif not isEnglish(token) and language_flag == 'English':
+                token_list_all.append(token_list_tmp)
+                language_list.append('English')
+                token_list_tmp = []
+
+            token_list_tmp.append(token)
+
+            if isEnglish(token):
+                language_flag = 'English'
             else:
-                # This is a Chinese char.
-                if len(current_word) > 0:
-                    words.append(current_word)
-                    current_word = ""
-                words.append(c)
-        if len(current_word) > 0:
-            words.append(current_word)
-    
-    return words
+                language_flag = 'Chinese'
+
+        if token_list_tmp:
+            token_list_all.append(token_list_tmp)
+            language_list.append(language_flag)
+
+        result_list = []
+        for token_list_tmp, language_flag in zip(token_list_all, language_list):
+            if language_flag == 'English':
+                result_list.extend(token_list_tmp)
+            else:
+                seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
+                result_list.extend(seg_list)
+
+        return result_list
+
+    else:
+        words = []
+        segs = text.split()
+        for seg in segs:
+            # There is no space in seg.
+            current_word = ""
+            for c in seg:
+                if len(c.encode()) == 1:
+                    # This is an ASCII char.
+                    current_word += c
+                else:
+                    # This is a Chinese char.
+                    if len(current_word) > 0:
+                        words.append(current_word)
+                        current_word = ""
+                    words.append(c)
+            if len(current_word) > 0:
+                words.append(current_word)
+        return words
+
+def isEnglish(text:str):
+    if re.search('^[a-zA-Z\']+$', text):
+        return True
+    else:
+        return False
+
+def join_chinese_and_english(input_list):
+    line = ''
+    for token in input_list:
+        if isEnglish(token):
+            line = line + ' ' + token
+        else:
+            line = line + token
+
+    line = line.strip()
+    return line

--
Gitblit v1.9.1