| | |
| | | from torch import nn |
| | | import random |
| | | import re |
| | | import string |
| | | from funasr.tokenizer.cleaner import TextCleaner |
| | | from funasr.register import tables |
| | | |
| | | |
@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    """Randomly speed-perturb a waveform using SoX effects.

    On every call one factor is drawn uniformly from ``speed_perturb``.
    A factor other than ``1.0`` is applied through the SoX ``speed``
    effect, followed by a ``rate`` effect with the input sampling rate.
    """

    def __init__(self, speed_perturb: list = None, **kwargs):
        """Store the candidate speed factors.

        Args:
            speed_perturb: Candidate speed factors, e.g. ``[0.9, 1.0, 1.1]``.
                ``None`` (or an empty list) disables perturbation.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        """Return ``waveform``, possibly speed-perturbed.

        Args:
            waveform: Audio samples. Anything that is not a ``torch.Tensor``
                is converted with ``torch.tensor`` before SoX is invoked.
                The data is flattened to a single channel of shape
                ``(1, num_samples)`` for the SoX call.
            fs: Sampling rate handed to SoX and used for the ``rate`` effect.
            **kwargs: Accepted and ignored.

        Returns:
            The input object unchanged when perturbation is disabled or the
            drawn factor equals ``1.0``; otherwise a flattened 1-D tensor
            holding the perturbed audio.
        """
        # An empty candidate list would make random.choice raise IndexError;
        # treat it the same as "perturbation disabled".
        if self.speed_perturb is None or len(self.speed_perturb) == 0:
            return waveform

        speed = random.choice(self.speed_perturb)
        if speed == 1.0:
            return waveform

        if not isinstance(waveform, torch.Tensor):
            waveform = torch.tensor(waveform)

        # reshape (rather than view) also accepts non-contiguous tensors and
        # is identical to view for contiguous ones.
        effects = [["speed", str(speed)], ["rate", str(fs)]]
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform.reshape(1, -1), fs, effects
        )
        return waveform.reshape(-1)
| | | |
| | | |
# NOTE(review): this span originally contained two interleaved variants of one
# class: two register decorators, two ``__init__`` signatures, and a
# punctuation-stripping body placed after ``return text``. They are separated
# into two classes here so each registry key maps to a coherent implementation.
# Confirm against the upstream history which variant(s) should remain.
@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
    """Text preprocessor that runs the input through a ``TextCleaner``."""

    def __init__(
        self,
        seg_dict: str = None,
        text_cleaner: Collection[str] = None,
        split_with_space: bool = False,
        **kwargs,
    ):
        """Build the cleaner.

        Args:
            seg_dict: Accepted for interface compatibility; not used here.
            text_cleaner: Passed straight to ``TextCleaner``.
            split_with_space: Accepted for interface compatibility; not used
                here.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.text_cleaner = TextCleaner(text_cleaner)

    def forward(self, text, **kwargs):
        """Return ``text`` after applying the configured cleaner."""
        return self.text_cleaner(text)


@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
class TextPreprocessRemovePunctuation(nn.Module):
    """Text preprocessor that deletes English and common Chinese punctuation."""

    # Commonly used Chinese punctuation marks (not exhaustive).
    _CN_PUNCT = '。?!,、;:“”‘’()《》【】…—~·'
    # One character class matching any English (string.punctuation) or Chinese
    # punctuation mark; compiled once instead of on every forward() call.
    _PUNCT_PATTERN = re.compile(
        '[{}]'.format(re.escape(string.punctuation + _CN_PUNCT))
    )

    def __init__(self, **kwargs):
        """Accept and ignore any configuration keywords."""
        super().__init__()

    def forward(self, text, **kwargs):
        """Return ``text`` with every matched punctuation character removed."""
        return self._PUNCT_PATTERN.sub('', text)