游雁
2024-02-23 fa6f60fa762f271d096b8749f3cc9bfc61a6ed48
funasr/datasets/llm_datasets/preprocessor.py
@@ -11,41 +11,27 @@
from torch import nn
import random
import re
import string
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables
@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    """Randomly speed-perturb a waveform with the sox ``speed`` + ``rate`` effects.

    On every call one factor is drawn with ``random.choice`` from
    ``speed_perturb``; a factor equal to 1.0 leaves the input untouched.

    Args:
        speed_perturb: Candidate speed factors (e.g. ``[0.9, 1.0, 1.1]``).
            ``None`` or an empty list disables perturbation.
    """

    def __init__(self, speed_perturb: list = None, **kwargs):
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        """Return ``waveform`` played back at a randomly chosen speed.

        Args:
            waveform: Audio samples, either a ``torch.Tensor`` or anything
                ``torch.tensor`` accepts. All elements are flattened into a
                single channel before sox runs (presumably the input is 1-D
                mono audio -- TODO confirm with callers).
            fs: Sample rate passed to sox; the ``rate`` effect resamples the
                result back to this rate.

        Returns:
            The input object itself when no perturbation is applied,
            otherwise a flattened 1-D ``torch.Tensor``.
        """
        # An empty list would make random.choice raise IndexError; treat it
        # the same as None (perturbation disabled).
        if not self.speed_perturb:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed == 1.0:
            return waveform
        if not isinstance(waveform, torch.Tensor):
            waveform = torch.tensor(waveform)
        # NOTE(review): sox may reject some dtypes (e.g. float64 produced by
        # torch.tensor on a float64 numpy array) -- confirm callers pass float32.
        # reshape (rather than view) also works for non-contiguous tensors.
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform.reshape(1, -1),
            fs,
            [["speed", str(speed)], ["rate", str(fs)]],
        )
        return waveform.reshape(-1)
@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
class TextPreprocessSegDict(nn.Module):
    """Remove English and common Chinese punctuation from a text string.

    Registered under ``TextPreprocessRemovePunctuation``. The Python class name
    ``TextPreprocessSegDict`` is left unchanged so existing imports keep
    working. Constructor keyword arguments (such as the legacy ``seg_dict`` /
    ``text_cleaner`` / ``split_with_space``) are accepted and ignored.
    """

    # English (ASCII) punctuation.
    _EN_PUNCT = string.punctuation
    # Commonly used Chinese punctuation (not exhaustive). Full-width forms are
    # required here: their ASCII look-alikes are already in string.punctuation,
    # and listing only those would leave e.g. "，" and "？" in the text.
    _CN_PUNCT = "。？！，、；：“”‘’（）《》【】…—～·"
    # One character class matching any character above; compiled once when
    # the class is created instead of on every forward() call.
    _PUNCT_PATTERN = re.compile("[{}]".format(re.escape(_EN_PUNCT + _CN_PUNCT)))

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, text, **kwargs):
        """Return ``text`` with every punctuation character removed.

        Args:
            text: Input string.

        Returns:
            ``text`` with all characters in the punctuation set deleted.
        """
        return self._PUNCT_PATTERN.sub("", text)