游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/datasets/small_datasets/preprocessor.py
@@ -10,8 +10,6 @@
import numpy as np
import scipy.signal
import soundfile
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
@@ -260,7 +258,6 @@
    def _speech_process(
            self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, Union[str, np.ndarray]]:
        assert check_argument_types()
        if self.speech_name in data:
            if self.train and (self.rirs is not None or self.noises is not None):
                speech = data[self.speech_name]
@@ -347,7 +344,6 @@
                speech = data[self.speech_name]
                ma = np.max(np.abs(speech))
                data[self.speech_name] = speech * self.speech_volume_normalize / ma
        assert check_return_type(data)
        return data
    def _text_process(
@@ -365,13 +361,11 @@
                tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data
    def __call__(
            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        assert check_argument_types()
        data = self._speech_process(data)
        data = self._text_process(data)
@@ -439,7 +433,6 @@
                tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data
@@ -496,13 +489,11 @@
                tokens = self.tokenizer.text2tokens(text)
                text_ints = self.token_id_converter.tokens2ids(tokens)
                data[text_n] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data
    def __call__(
            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        assert check_argument_types()
        if self.speech_name in data:
            # Nothing now: candidates:
@@ -606,7 +597,6 @@
                tokens = self.tokenizer[i].text2tokens(text)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data
@@ -685,7 +675,6 @@
    def __call__(
            self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
    ) -> Dict[str, Union[list, np.ndarray]]:
        assert check_argument_types()
        # Split words.
        if isinstance(data[self.text_name], str):
            split_text = self.split_words(data[self.text_name])
@@ -820,7 +809,7 @@
def build_preprocess(args, train):
    if args.use_preprocessor:
    if not args.use_preprocessor:
        return None
    if args.task_name in ["asr", "data2vec", "diar", "sv"]:
        retval = CommonPreprocessor(
@@ -828,7 +817,7 @@
            token_type=args.token_type,
            token_list=args.token_list,
            bpemodel=args.bpemodel,
            non_linguistic_symbols=args.non_linguistic_symbols,
            non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
            text_cleaner=args.cleaner,
            g2p_type=args.g2p,
            split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
@@ -855,6 +844,19 @@
            text_name=text_names,
            non_linguistic_symbols=args.non_linguistic_symbols,
        )
    elif args.task_name == "lm":
        retval = LMPreprocessor(
            train=train,
            token_type=args.token_type,
            token_list=args.token_list,
            bpemodel=args.bpemodel,
            text_cleaner=args.cleaner,
            g2p_type=args.g2p,
            text_name="text",
            non_linguistic_symbols=args.non_linguistic_symbols,
            split_with_space=args.split_with_space,
            seg_dict_file=args.seg_dict_file
        )
    elif args.task_name == "vad":
        retval = None
    else: