| | |
| | | |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)) |
| | | vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None |
| | | vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size |
| | | model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) |
| | | |
| | | |
| | | |
| | |
| | | path=p, |
| | | ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), |
| | | oss_bucket=kwargs.get("oss_bucket", None), |
| | | scope_map=kwargs.get("scope_map", None), |
| | | scope_map=kwargs.get("scope_map", []), |
| | | excludes=kwargs.get("excludes", None), |
| | | ) |
| | | else: |
| | | logging.info(f"Checkpoint does not exist, init randomly: {p}") |
| | | else: |
| | | elif kwargs.get("init", None): |
| | | initialize(model, kwargs.get("init", "kaiming_normal")) |
| | | else: |
| | | print("No initialize method") |
| | | |
| | | |
# freeze_param