kongdeqiang
5 days ago 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/datasets/audio_datasets/preprocessor.py
@@ -17,67 +17,39 @@
@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    """Randomly speed-perturb a waveform with sox effects.

    A speed factor is drawn uniformly from ``speed_perturb`` on every
    forward call; a ``rate`` effect resamples back to the original
    sampling rate so the output keeps the same ``fs``.
    """

    def __init__(self, speed_perturb: list = None, **kwargs):
        super().__init__()
        # Candidate speed factors (e.g. [0.9, 1.0, 1.1]); None disables perturbation.
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        """Return `waveform` (1-D) perturbed by a random speed factor.

        Args:
            waveform: 1-D audio samples (torch.Tensor or array-like).
            fs: sampling rate in Hz.
        """
        if self.speed_perturb is None:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            # sox effects require a torch.Tensor; convert array-like input once.
            if not isinstance(waveform, torch.Tensor):
                waveform = torch.tensor(waveform)
            # `speed` changes duration/pitch rate; the trailing `rate` effect
            # resamples the result back to the original sampling rate fs.
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
            )
            waveform = waveform.view(-1)
        return waveform
@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
    """Clean text and optionally re-segment it with a word->pieces dictionary.

    The segmentation dictionary file has one entry per line:
    ``word piece1 piece2 ...``; each word maps to its space-joined pieces.
    """

    def __init__(
        self,
        seg_dict: str = None,
        text_cleaner: Collection[str] = None,
        split_with_space: bool = False,
        **kwargs,
    ):
        super().__init__()
        # Mapping word -> space-joined sub-word pieces; None when no dict file given.
        self.seg_dict = None
        if seg_dict is not None:
            self.seg_dict = {}
            with open(seg_dict, "r", encoding="utf8") as f:
                # Iterate the file directly instead of readlines() to avoid
                # materializing the whole file in memory.
                for line in f:
                    parts = line.strip().split()
                    if parts:
                        self.seg_dict[parts[0]] = " ".join(parts[1:])
        self.text_cleaner = TextCleaner(text_cleaner)
        self.split_with_space = split_with_space

    def forward(self, text, **kwargs):
        """Return cleaned (and, when configured, re-segmented) `text`.

        When no seg_dict is loaded the input is returned unchanged.
        """
        if self.seg_dict is not None:
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                # NOTE: the redundant inner `if self.seg_dict is not None`
                # check was removed; the outer guard already ensures it.
                text = seg_tokenize(tokens, self.seg_dict)
        return text
def seg_tokenize(txt, seg_dict):
    """Tokenize a list of words through a segmentation dictionary.

    Each word is lower-cased and looked up in ``seg_dict``. A missing word
    made purely of CJK ideographs and/or ASCII digits falls back to a
    per-character lookup; anything else becomes ``<unk>``. Returns the
    resulting pieces as a flat list of tokens.
    """
    cjk_or_digit = re.compile(r'^[\u4E00-\u9FA50-9]+$')
    pieces = []
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            pieces.append(seg_dict[word])
        elif cjk_or_digit.match(word):
            # Per-character fallback for CJK/digit-only words.
            pieces.extend(
                seg_dict[char] if char in seg_dict else "<unk>" for char in word
            )
        else:
            pieces.append("<unk>")
    # Joining with spaces and re-splitting flattens multi-piece entries
    # exactly like the original space-accumulating string build did.
    return " ".join(pieces).split()
    def forward(self, text, **kwargs):
        text = self.text_cleaner(text)
        return text