| | |
| | | # build tokenizer |
| | | tokenizer = kwargs.get("tokenizer", None) |
| | | if tokenizer is not None: |
| | | tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower()) |
| | | tokenizer_class = tables.tokenizer_classes.get(tokenizer) |
| | | tokenizer = tokenizer_class(**kwargs["tokenizer_conf"]) |
| | | kwargs["tokenizer"] = tokenizer |
| | | kwargs["token_list"] = tokenizer.token_list |
| | |
| | | # build frontend |
| | | frontend = kwargs.get("frontend", None) |
| | | if frontend is not None: |
| | | frontend_class = tables.frontend_classes.get(frontend.lower()) |
| | | frontend_class = tables.frontend_classes.get(frontend) |
| | | frontend = frontend_class(**kwargs["frontend_conf"]) |
| | | kwargs["frontend"] = frontend |
| | | kwargs["input_size"] = frontend.output_size() |
| | | |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"].lower()) |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) |
| | | model.eval() |
| | | model.to(device) |
| | |
| | | # build frontend |
| | | frontend = kwargs.get("frontend", None) |
| | | if frontend is not None: |
| | | frontend_class = tables.frontend_classes.get(frontend.lower()) |
| | | frontend_class = tables.frontend_classes.get(frontend) |
| | | frontend = frontend_class(**kwargs["frontend_conf"]) |
| | | |
| | | self.frontend = frontend |
| | |
| | | |
| | | tokenizer = kwargs.get("tokenizer", None) |
| | | if tokenizer is not None: |
| | | tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower()) |
| | | tokenizer_class = tables.tokenizer_classes.get(tokenizer) |
| | | tokenizer = tokenizer_class(**kwargs["tokenizer_conf"]) |
| | | kwargs["tokenizer"] = tokenizer |
| | | |
| | | # build frontend if frontend is none None |
| | | frontend = kwargs.get("frontend", None) |
| | | if frontend is not None: |
| | | frontend_class = tables.frontend_classes.get(frontend.lower()) |
| | | frontend_class = tables.frontend_classes.get(frontend) |
| | | frontend = frontend_class(**kwargs["frontend_conf"]) |
| | | kwargs["frontend"] = frontend |
| | | kwargs["input_size"] = frontend.output_size() |
| | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"].lower()) |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)) |
| | | |
| | | |
| | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | # dataset |
| | | dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower()) |
| | | dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset")) |
| | | dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf")) |
| | | |
| | | # dataloader |
| | | batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler") |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler.lower()) |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) |
| | | if batch_sampler is not None: |
| | | batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) |
| | | dataloader_tr = torch.utils.data.DataLoader(dataset_tr, |
| | |
| | | |
| | | @tables.register("dataset_classes", "AudioDataset") |
| | | class AudioDataset(torch.utils.data.Dataset): |
| | | """ |
| | | AudioDataset |
| | | """ |
| | | def __init__(self, |
| | | path, |
| | | index_ds: str = None, |
| | |
| | | float_pad_value: float = 0.0, |
| | | **kwargs): |
| | | super().__init__() |
| | | index_ds_class = tables.index_ds_classes.get(index_ds.lower()) |
| | | index_ds_class = tables.index_ds_classes.get(index_ds) |
| | | self.index_ds = index_ds_class(path) |
| | | preprocessor_speech = kwargs.get("preprocessor_speech", None) |
| | | if preprocessor_speech: |
| | | preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech.lower()) |
| | | preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech) |
| | | preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf")) |
| | | self.preprocessor_speech = preprocessor_speech |
| | | preprocessor_text = kwargs.get("preprocessor_text", None) |
| | | if preprocessor_text: |
| | | preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text.lower()) |
| | | preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text) |
| | | preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) |
| | | self.preprocessor_text = preprocessor_text |
| | | |
| | |
| | | |
| | | |
| | | self.embed = nn.Embedding(vocab_size, embed_unit) |
| | | encoder_class = tables.encoder_classes.get(encoder.lower()) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(**encoder_conf) |
| | | |
| | | self.decoder = nn.Linear(att_unit, punc_size) |
| | |
| | | super().__init__() |
| | | self.vad_opts = VADXOptions(**kwargs) |
| | | |
| | | encoder_class = tables.encoder_classes.get(encoder.lower()) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(**encoder_conf) |
| | | self.encoder = encoder |
| | | |
| | |
| | | super().__init__() |
| | | |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug.lower()) |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = tables.normalize_classes.get(normalize.lower()) |
| | | normalize_class = tables.normalize_classes.get(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | encoder_class = tables.encoder_classes.get(encoder.lower()) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(input_size=input_size, **encoder_conf) |
| | | encoder_output_size = encoder.output_size() |
| | | predictor_class = tables.predictor_classes.get(predictor.lower()) |
| | | predictor_class = tables.predictor_classes.get(predictor) |
| | | predictor = predictor_class(**predictor_conf) |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | |
| | | super().__init__() |
| | | |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug.lower()) |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = tables.normalize_classes.get(normalize.lower()) |
| | | normalize_class = tables.normalize_classes.get(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | encoder_class = tables.encoder_classes.get(encoder.lower()) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(input_size=input_size, **encoder_conf) |
| | | encoder_output_size = encoder.output_size() |
| | | |
| | | if decoder is not None: |
| | | decoder_class = tables.decoder_classes.get(decoder.lower()) |
| | | decoder_class = tables.decoder_classes.get(decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder_output_size, |
| | |
| | | odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf |
| | | ) |
| | | if predictor is not None: |
| | | predictor_class = tables.predictor_classes.get(predictor.lower()) |
| | | predictor_class = tables.predictor_classes.get(predictor) |
| | | predictor = predictor_class(**predictor_conf) |
| | | |
| | | # note that eos is the same as sos (equivalent ID) |
| | |
| | | seaco_decoder = kwargs.get("seaco_decoder", None) |
| | | if seaco_decoder is not None: |
| | | seaco_decoder_conf = kwargs.get("seaco_decoder_conf") |
| | | seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower()) |
| | | seaco_decoder_class = tables.decoder_classes.get(seaco_decoder) |
| | | self.seaco_decoder = seaco_decoder_class( |
| | | vocab_size=self.vocab_size, |
| | | encoder_output_size=self.inner_dim, |
| | |
| | | super().__init__() |
| | | |
| | | if frontend is not None: |
| | | frontend_class = tables.frontend_classes.get_class(frontend.lower()) |
| | | frontend_class = tables.frontend_classes.get_class(frontend) |
| | | frontend = frontend_class(**frontend_conf) |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get_class(specaug.lower()) |
| | | specaug_class = tables.specaug_classes.get_class(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = tables.normalize_classes.get_class(normalize.lower()) |
| | | normalize_class = tables.normalize_classes.get_class(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | encoder_class = tables.encoder_classes.get_class(encoder.lower()) |
| | | encoder_class = tables.encoder_classes.get_class(encoder) |
| | | encoder = encoder_class(input_size=input_size, **encoder_conf) |
| | | encoder_output_size = encoder.output_size() |
| | | if decoder is not None: |
| | | decoder_class = tables.decoder_classes.get_class(decoder.lower()) |
| | | decoder_class = tables.decoder_classes.get_class(decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder_output_size, |
| | |
| | | import logging |
| | | import inspect |
| | | from dataclasses import dataclass |
| | | |
| | | import re |
| | | |
| | | @dataclass |
| | | class RegisterTables: |
| | |
| | | flag = key in classes_key |
| | | if classes_key.endswith("_meta") and flag: |
| | | print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------") |
| | | headers = ["class name", "register name", "class location"] |
| | | headers = ["class name", "class location"] |
| | | metas = [] |
| | | for register_key, meta in classes_dict.items(): |
| | | metas.append(meta) |
| | |
| | | |
| | | registry = getattr(self, register_tables_key) |
| | | registry_key = key if key is not None else target_class.__name__ |
| | | registry_key = registry_key.lower() |
| | | # import pdb; pdb.set_trace() |
| | | |
| | | assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format( |
| | | registry_key, target_class, register_tables_key) |
| | | |
| | |
| | | if not hasattr(self, register_tables_key_meta): |
| | | setattr(self, register_tables_key_meta, {}) |
| | | registry_meta = getattr(self, register_tables_key_meta) |
| | | # doc = target_class.__doc__ |
| | | class_file = inspect.getfile(target_class) |
| | | class_line = inspect.getsourcelines(target_class)[1] |
| | | meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"] |
| | | pattern = r'^.+/funasr/' |
| | | class_file = re.sub(pattern, 'funasr/', class_file) |
| | | meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"] |
| | | # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"] |
| | | registry_meta[registry_key] = meata_data |
| | | # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}") |
| | | return target_class |