kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/datasets/llm_datasets/preprocessor.py
@@ -11,41 +11,24 @@
from torch import nn
import random
import re
import string
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables
@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    """Speed-perturbation preprocessor for raw audio.

    Each call draws one factor at random from ``speed_perturb`` and, unless it
    equals 1.0, applies the sox ``speed`` effect followed by ``rate`` so the
    result is resampled back to the input sample rate.
    """

    def __init__(self, speed_perturb: list = None, **kwargs):
        """Store the candidate speed factors.

        Args:
            speed_perturb: candidate factors, e.g. ``[0.9, 1.0, 1.1]``. ``None``
                or an empty sequence disables perturbation.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        """Optionally speed-perturb ``waveform``.

        Args:
            waveform: audio samples; a ``torch.Tensor`` or any array-like that
                ``torch.tensor`` accepts. It is flattened to a single channel
                before the effect is applied.
            fs: sample rate handed to sox, used both as the input rate and as
                the target of the ``rate`` effect.
            **kwargs: accepted and ignored.

        Returns:
            The original ``waveform`` object when perturbation is disabled or
            the drawn factor is 1.0; otherwise a new 1-D tensor.
        """
        # Treat an empty sequence like None; random.choice would otherwise
        # raise IndexError. len() (not truthiness) keeps array-likes working.
        if self.speed_perturb is None or len(self.speed_perturb) == 0:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed == 1.0:
            return waveform
        if not isinstance(waveform, torch.Tensor):
            waveform = torch.tensor(waveform)
        # "speed" alters the effective sample rate; "rate" resamples back to fs
        # so the output keeps the input rate. reshape (unlike view) also
        # accepts non-contiguous tensors.
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform.reshape(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
        )
        return perturbed.reshape(-1)
@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
class TextPreprocessRemovePunctuation(nn.Module):
    """Strip English (ASCII) and common Chinese punctuation from a string."""

    # English punctuation: the ASCII set from the standard library.
    _EN_PUNCT = string.punctuation
    # Common Chinese punctuation (a partial list).
    _CN_PUNCT = (
        "。、“”‘’《》【】…—·"
        # Full-width ? ! , ; : ( ) ~ (U+FF1F U+FF01 U+FF0C U+FF1B U+FF1A
        # U+FF08 U+FF09 U+FF5E). They are distinct code points from their ASCII
        # look-alikes in _EN_PUNCT, so they must be listed explicitly; escapes
        # keep them from being mistaken for (or normalized into) ASCII.
        "\uff1f\uff01\uff0c\uff1b\uff1a\uff08\uff09\uff5e"
    )
    # Deletion table built once; str.translate removes every listed character
    # in a single pass.
    _DELETE_TABLE = str.maketrans("", "", _EN_PUNCT + _CN_PUNCT)

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, text, **kwargs):
        """Return ``text`` with every punctuation character removed.

        Args:
            text: input string.
            **kwargs: accepted and ignored.

        Returns:
            str: ``text`` with the characters in ``_EN_PUNCT`` and
            ``_CN_PUNCT`` deleted; all other characters (including whitespace)
            are kept in their original order.
        """
        return text.translate(self._DELETE_TABLE)


@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
    """Clean text with a configurable ``TextCleaner``.

    NOTE(review): ``seg_dict`` and ``split_with_space`` are accepted but never
    stored or used here; confirm whether segmentation was meant to happen in
    this preprocessor.
    """

    def __init__(self, seg_dict: str = None,
                 text_cleaner: Collection[str] = None,
                 split_with_space: bool = False,
                 **kwargs):
        super().__init__()
        self.text_cleaner = TextCleaner(text_cleaner)

    def forward(self, text, **kwargs):
        """Return ``text`` after passing it through the configured cleaner."""
        text = self.text_cleaner(text)
        return text