import logging
import os

import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf

from funasr.register import tables

# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
# from funasr.tokenizer.funtoken import build_tokenizer

from funasr import AutoModel


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    # Distributed context (assumption: torchrun-style environment variables,
    # which is how FunASR training is normally launched).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    use_ddp = int(os.environ.get("WORLD_SIZE", 1)) > 1
    use_fsdp = kwargs.get("use_fsdp", False)
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
        torch.cuda.set_device(local_rank)
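
    # Example launch for two GPUs (hypothetical overrides; torchrun supplies
    # the WORLD_SIZE/LOCAL_RANK variables read above):
    #   torchrun --nproc_per_node=2 train.py ++output_dir=./exp ++device=cuda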

    # Build once on CPU so checkpoint/config preparation happens before any
    # device placement; the target device is recorded for later use.
    device = kwargs.get("device", "cpu")
    kwargs["device"] = "cpu"
    model = AutoModel(**kwargs)
    kwargs["device"] = device
    model = model.model
    # These entries may still be registry names from the config; the blocks
    # below instantiate them.
    tokenizer = kwargs["tokenizer"]
    frontend = kwargs["frontend"]

    # save config.yaml (rank-0 process only, with or without DDP/FSDP)
    if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
        # Assumption: the merged config is written alongside the training outputs.
        output_dir = kwargs.get("output_dir", "./exp")
        os.makedirs(output_dir, exist_ok=True)
        yaml_file = os.path.join(output_dir, "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)

    # build tokenizer from the registry when the config names one
    tokenizer = kwargs.get("tokenizer", None)
    if tokenizer is not None:
        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer

    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
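
    # Registry note: kwargs["model"] is a name registered in
    # tables.model_classes, e.g. "Paraformer" in FunASR; the lookup returns
    # the class itself, which is then called with the merged config.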

    # init_param
    init_param = kwargs.get("init_param", None)
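
    # init_param typically points at pretrained weights. A minimal way to
    # consume it (an assumption, not necessarily the upstream logic) would be:
    #   if init_param is not None:
    #       model.load_state_dict(torch.load(init_param, map_location="cpu"), strict=False)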