| New file |
| | |
| | | import os |
| | | import json |
| | | import torch |
| | | import logging |
| | | import concurrent.futures |
| | | import librosa |
| | | import torch.distributed as dist |
| | | from typing import Collection |
| | | import torch |
| | | import torchaudio |
| | | from torch import nn |
| | | import random |
| | | import re |
| | | import string |
| | | from funasr.tokenizer.cleaner import TextCleaner |
| | | from funasr.register import tables |
| | | |
| | | |
| | | |
| | | @tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation") |
| | | class TextPreprocessSegDict(nn.Module): |
| | | def __init__(self, |
| | | **kwargs): |
| | | super().__init__() |
| | | |
| | | |
| | | def forward(self, text, **kwargs): |
| | | # 定义英文标点符号 |
| | | en_punct = string.punctuation |
| | | # 定义中文标点符号(部分常用的) |
| | | cn_punct = '。?!,、;:“”‘’()《》【】…—~·' |
| | | # 合并英文和中文标点符号 |
| | | all_punct = en_punct + cn_punct |
| | | # 创建正则表达式模式,匹配任何在all_punct中的字符 |
| | | punct_pattern = re.compile('[{}]'.format(re.escape(all_punct))) |
| | | # 使用正则表达式的sub方法替换掉这些字符 |
| | | return punct_pattern.sub('', text) |