From 835369d6315e96c1820326ed11ea4b999793720f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Sat, 13 Jan 2024 22:42:18 +0800
Subject: [PATCH] funasr1.0 fix punc model
---
funasr/models/ct_transformer/utils.py | 115 +++++++++++++++++++++++++++++++------
examples/industrial_data_pretraining/ct_transformer/demo.py | 10 +++
examples/industrial_data_pretraining/paraformer-zh-spk/demo.py | 2
examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh | 2
funasr/download/download_from_hub.py | 2
funasr/models/ct_transformer/model.py | 10 ++
examples/industrial_data_pretraining/ct_transformer/infer.sh | 5 +
examples/industrial_data_pretraining/bicif_paraformer/demo.py | 4
examples/industrial_data_pretraining/bicif_paraformer/infer.sh | 2
examples/industrial_data_pretraining/seaco_paraformer/infer.sh | 2
examples/industrial_data_pretraining/seaco_paraformer/demo.py | 2
11 files changed, 125 insertions(+), 31 deletions(-)
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/demo.py b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
index 84b0e80..4d921ea 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -10,7 +10,7 @@
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.1",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.0",
+ punc_model_revision="v2.0.1",
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
)
@@ -23,7 +23,7 @@
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.1",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.0",
+ punc_model_revision="v2.0.1",
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
spk_mode='punc_segment',
)
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/infer.sh b/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
index 04cb6f2..57c5838 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
@@ -4,7 +4,7 @@
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
vad_model_revision="v2.0.0"
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/examples/industrial_data_pretraining/ct_transformer/demo.py b/examples/industrial_data_pretraining/ct_transformer/demo.py
index 58ebd2a..23965e0 100644
--- a/examples/industrial_data_pretraining/ct_transformer/demo.py
+++ b/examples/industrial_data_pretraining/ct_transformer/demo.py
@@ -5,7 +5,15 @@
from funasr import AutoModel
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.0")
+model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.1")
+
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+print(res)
+
+
+from funasr import AutoModel
+
+model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.1")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/infer.sh b/examples/industrial_data_pretraining/ct_transformer/infer.sh
index a48d562..4b4e949 100644
--- a/examples/industrial_data_pretraining/ct_transformer/infer.sh
+++ b/examples/industrial_data_pretraining/ct_transformer/infer.sh
@@ -1,6 +1,9 @@
model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-model_revision="v2.0.0"
+model_revision="v2.0.1"
+
+model="damo/punc_ct-transformer_cn-en-common-vocab471067-large"
+model_revision="v2.0.1"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
index 774d757..fcf5f60 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -10,7 +10,7 @@
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.1",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.0",
+ punc_model_revision="v2.0.1",
spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
spk_model_revision="v2.0.0"
)
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
index a457401..63347b6 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
@@ -4,7 +4,7 @@
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
vad_model_revision="v2.0.1"
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
spk_model_revision="v2.0.0"
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 63f155e..3b5963a 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -10,7 +10,7 @@
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.1",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.0",
+ punc_model_revision="v2.0.1",
)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
index 26eeee1..c46449f 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
@@ -4,7 +4,7 @@
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
vad_model_revision="v2.0.1"
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 946572f..27bd79d 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -37,6 +37,8 @@
kwargs["model"] = cfg["model"]
if os.path.exists(os.path.join(model_or_path, "am.mvn")):
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+ if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+ kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
else:# configuration.json
assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index fbf1804..d843686 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -225,8 +225,14 @@
# text = data_in[0]
# text_lengths = data_lengths[0] if data_lengths is not None else None
split_size = kwargs.get("split_size", 20)
-
- tokens = split_words(text)
+
+ jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
+ if jieba_usr_dict and isinstance(jieba_usr_dict, str):
+ import jieba
+ jieba.load_userdict(jieba_usr_dict)
+ jieba_usr_dict = jieba
+ kwargs["jieba_usr_dict"] = "jieba_usr_dict"
+ tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
tokens_int = tokenizer.encode(tokens)
mini_sentences = split_to_mini_sentence(tokens, split_size)
diff --git a/funasr/models/ct_transformer/utils.py b/funasr/models/ct_transformer/utils.py
index a4a00e0..917f2e0 100644
--- a/funasr/models/ct_transformer/utils.py
+++ b/funasr/models/ct_transformer/utils.py
@@ -1,4 +1,4 @@
-
+import re
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
@@ -14,23 +14,98 @@
return sentences
-def split_words(text: str):
- words = []
- segs = text.split()
- for seg in segs:
- # There is no space in seg.
- current_word = ""
- for c in seg:
- if len(c.encode()) == 1:
- # This is an ASCII char.
- current_word += c
+# def split_words(text: str, **kwargs):
+# words = []
+# segs = text.split()
+# for seg in segs:
+# # There is no space in seg.
+# current_word = ""
+# for c in seg:
+# if len(c.encode()) == 1:
+# # This is an ASCII char.
+# current_word += c
+# else:
+# # This is a Chinese char.
+# if len(current_word) > 0:
+# words.append(current_word)
+# current_word = ""
+# words.append(c)
+# if len(current_word) > 0:
+# words.append(current_word)
+#
+# return words
+
+def split_words(text: str, jieba_usr_dict=None, **kwargs):
+ if jieba_usr_dict:
+ input_list = text.split()
+ token_list_all = []
+ langauge_list = []
+ token_list_tmp = []
+ language_flag = None
+ for token in input_list:
+ if isEnglish(token) and language_flag == 'Chinese':
+ token_list_all.append(token_list_tmp)
+ langauge_list.append('Chinese')
+ token_list_tmp = []
+ elif not isEnglish(token) and language_flag == 'English':
+ token_list_all.append(token_list_tmp)
+ langauge_list.append('English')
+ token_list_tmp = []
+
+ token_list_tmp.append(token)
+
+ if isEnglish(token):
+ language_flag = 'English'
else:
- # This is a Chinese char.
- if len(current_word) > 0:
- words.append(current_word)
- current_word = ""
- words.append(c)
- if len(current_word) > 0:
- words.append(current_word)
-
- return words
+ language_flag = 'Chinese'
+
+ if token_list_tmp:
+ token_list_all.append(token_list_tmp)
+ langauge_list.append(language_flag)
+
+ result_list = []
+ for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
+ if language_flag == 'English':
+ result_list.extend(token_list_tmp)
+ else:
+ seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
+ result_list.extend(seg_list)
+
+ return result_list
+
+ else:
+ words = []
+ segs = text.split()
+ for seg in segs:
+ # There is no space in seg.
+ current_word = ""
+ for c in seg:
+ if len(c.encode()) == 1:
+ # This is an ASCII char.
+ current_word += c
+ else:
+ # This is a Chinese char.
+ if len(current_word) > 0:
+ words.append(current_word)
+ current_word = ""
+ words.append(c)
+ if len(current_word) > 0:
+ words.append(current_word)
+ return words
+
+def isEnglish(text:str):
+ if re.search('^[a-zA-Z\']+$', text):
+ return True
+ else:
+ return False
+
+def join_chinese_and_english(input_list):
+ line = ''
+ for token in input_list:
+ if isEnglish(token):
+ line = line + ' ' + token
+ else:
+ line = line + token
+
+ line = line.strip()
+ return line
--
Gitblit v1.9.1