From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/datasets/llm_datasets/preprocessor.py | 49 ++++++++++++++++---------------------------------
1 files changed, 16 insertions(+), 33 deletions(-)
diff --git a/funasr/datasets/llm_datasets/preprocessor.py b/funasr/datasets/llm_datasets/preprocessor.py
index ab75140..b99255e 100644
--- a/funasr/datasets/llm_datasets/preprocessor.py
+++ b/funasr/datasets/llm_datasets/preprocessor.py
@@ -11,41 +11,24 @@
from torch import nn
import random
import re
+import string
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables
-@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
-class SpeechPreprocessSpeedPerturb(nn.Module):
- def __init__(self, speed_perturb: list=None, **kwargs):
- super().__init__()
- self.speed_perturb = speed_perturb
-
- def forward(self, waveform, fs, **kwargs):
- if self.speed_perturb is None:
- return waveform
- speed = random.choice(self.speed_perturb)
- if speed != 1.0:
- if not isinstance(waveform, torch.Tensor):
- waveform = torch.tensor(waveform)
- waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
- waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
- waveform = waveform.view(-1)
-
- return waveform
+@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
+class TextPreprocessRemovePunctuation(nn.Module):  # text preprocessor: forward() returns the text with punctuation characters removed
+ def __init__(self, **kwargs):
+ super().__init__()
-
-@tables.register("preprocessor_classes", "TextPreprocessSegDict")
-class TextPreprocessSegDict(nn.Module):
- def __init__(self, seg_dict: str = None,
- text_cleaner: Collection[str] = None,
- split_with_space: bool = False,
- **kwargs):
- super().__init__()
-
- self.text_cleaner = TextCleaner(text_cleaner)
-
- def forward(self, text, **kwargs):
- text = self.text_cleaner(text)
-
- return text
+ def forward(self, text, **kwargs):
+ # Define the English (ASCII) punctuation characters
+ en_punct = string.punctuation
+ # Define the Chinese punctuation characters (a subset of the common ones). NOTE(review): the literal below looks mis-decoded in this patch text (UTF-8 read as GBK) - verify it against the repository file
+ cn_punct = "銆傦紵锛侊紝銆侊紱锛氣�溾�濃�樷�欙紙锛夈�娿�嬨�愩�戔�︹�旓綖路"
+ # Combine the English and Chinese punctuation characters
+ all_punct = en_punct + cn_punct
+ # Build a regex character class matching any single character in all_punct; re.escape keeps characters such as ']', '\' and '-' literal inside the brackets. NOTE(review): the pattern is rebuilt on every call
+ punct_pattern = re.compile("[{}]".format(re.escape(all_punct)))
+ # Use the pattern's sub() method to delete every matching character and return the result
+ return punct_pattern.sub("", text)
--
Gitblit v1.9.1