From fa6f60fa762f271d096b8749f3cc9bfc61a6ed48 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 23 Feb 2024 14:01:44 +0800
Subject: [PATCH] update

---
 funasr/datasets/llm_datasets/preprocessor.py |   40 +++++++++++++---------------------------
 1 files changed, 13 insertions(+), 27 deletions(-)

diff --git a/funasr/datasets/llm_datasets/preprocessor.py b/funasr/datasets/llm_datasets/preprocessor.py
index ab75140..9f20672 100644
--- a/funasr/datasets/llm_datasets/preprocessor.py
+++ b/funasr/datasets/llm_datasets/preprocessor.py
@@ -11,41 +11,27 @@
 from torch import nn
 import random
 import re
+import string
 from funasr.tokenizer.cleaner import TextCleaner
 from funasr.register import tables
 
 
-@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
-class SpeechPreprocessSpeedPerturb(nn.Module):
-	def __init__(self, speed_perturb: list=None, **kwargs):
-		super().__init__()
-		self.speed_perturb = speed_perturb
-		
-	def forward(self, waveform, fs, **kwargs):
-		if self.speed_perturb is None:
-			return waveform
-		speed = random.choice(self.speed_perturb)
-		if speed != 1.0:
-			if not isinstance(waveform, torch.Tensor):
-				waveform = torch.tensor(waveform)
-			waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
-				waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
-			waveform = waveform.view(-1)
-			
-		return waveform
 
-
-@tables.register("preprocessor_classes", "TextPreprocessSegDict")
+@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
 class TextPreprocessSegDict(nn.Module):
-	def __init__(self, seg_dict: str = None,
-	             text_cleaner: Collection[str] = None,
-	             split_with_space: bool = False,
+	def __init__(self,
 	             **kwargs):
 		super().__init__()
 		
-		self.text_cleaner = TextCleaner(text_cleaner)
 	
 	def forward(self, text, **kwargs):
-		text = self.text_cleaner(text)
-		
-		return text
+		"""Return ``text`` with English and common Chinese punctuation removed."""
+		en_punct = string.punctuation
+		# Chinese (full-width) punctuation marks -- a subset of the commonly used ones.
+		cn_punct = '。？！，、；：“”‘’（）《》【】…—～·'
+		# Merge the English and Chinese punctuation sets.
+		all_punct = en_punct + cn_punct
+		# Build a character class matching any character in all_punct; re.escape keeps ']' and '\' literal.
+		punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
+		# Replace every matched punctuation character with the empty string.
+		return punct_pattern.sub('', text)

--
Gitblit v1.9.1