游雁
2024-02-19 1448e021accfdb03a381651cb5a8be6d1a6e8adf
funasr/datasets/audio_datasets/datasets.py
@@ -19,7 +19,7 @@
                  **kwargs):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
@@ -63,9 +63,14 @@
        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        ids = self.tokenizer.encode(target)
        if self.tokenizer:
            ids = self.tokenizer.encode(target)
            text = torch.tensor(ids, dtype=torch.int64)
        else:
            ids = target
            text = ids
        ids_lengths = len(ids)
        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
        return {"speech": speech[0, :, :],
                "speech_lengths": speech_lengths,
@@ -83,11 +88,13 @@
                outputs[key].append(sample[key])
        for key, data_list in outputs.items():
            if data_list[0].dtype == torch.int64:
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return outputs