| | |
| | | class AutoModel: |
| | | |
| | | def __init__(self, **kwargs): |
| | | |
| | | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) |
| | | logging.basicConfig(level=log_level) |
| | | |
| | | if not kwargs.get("disable_log", True): |
| | | tables.print() |
| | | |
| | |
| | | return cfg_item |
| | | |
| | | kwargs = to_plain_list(cfg) |
| | | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) |
| | | |
| | | logging.basicConfig(level=log_level) |
| | | |
| | | if kwargs.get("debug", False): |
| | | import pdb |
| | |
| | | return cfg_item |
| | | |
| | | kwargs = to_plain_list(cfg) |
| | | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) |
| | | |
| | | logging.basicConfig(level=log_level) |
| | | |
| | | if kwargs.get("debug", False): |
| | | import pdb |
| | |
| | | tokenizer=tokenizer, |
| | | ) |
| | | |
| | | if len(kwargs.get("data_type", [])) > 1: |
| | | if ( |
| | | isinstance(kwargs.get("data_type", None), (list, tuple)) |
| | | and len(kwargs.get("data_type", [])) > 1 |
| | | ): |
| | | audio_sample_list, text_token_int_list = audio_sample_list |
| | | text_token_int = text_token_int_list[0] |
| | | else: |
| | |
| | | ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[ |
| | | None, : |
| | | ] |
| | | ys_pad_lens = torch.tensor([len(text_token_int)], dtype=torch.int64).to( |
| | | ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to( |
| | | kwargs["device"] |
| | | )[None, :] |
| | | decoder_out = self.model.decoder( |
| | |
| | | # post process of one iteration |
| | | running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) |
| | | # end detection |
| | | # if len(ended_hyps) > 0: |
| | | # print(f"ended_hyps: {ended_hyps}") |
| | | if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): |
| | | logging.info(f"end detected at {i}") |
| | | break |