From 557b913c5d78945e49cb2ac2bf254a2de40b6cd5 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 14 Mar 2024 11:48:31 +0800
Subject: [PATCH] v1.0.16
---
funasr/bin/train_llm.py | 35 +++++++++++++----------------------
 1 file changed, 13 insertions(+), 22 deletions(-)
diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
index 3c93371..a33cd53 100644
--- a/funasr/bin/train_llm.py
+++ b/funasr/bin/train_llm.py
@@ -26,7 +26,7 @@
# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
# from funasr.tokenizer.funtoken import build_tokenizer
-
+from funasr import AutoModel
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
@@ -60,6 +60,16 @@
if use_ddp or use_fsdp:
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
torch.cuda.set_device(local_rank)
+
+ device = kwargs.get("device", "cpu")
+ kwargs["device"] = "cpu"
+ model = AutoModel(**kwargs)
+ kwargs["device"] = device
+ model = model.model
+ tokenizer = kwargs["tokenizer"]
+ frontend = kwargs["frontend"]
+
+
# save config.yaml
if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
@@ -68,28 +78,9 @@
OmegaConf.save(config=kwargs, f=yaml_file)
logging.info("config.yaml is saved to: %s", yaml_file)
- tokenizer = kwargs.get("tokenizer", None)
- if tokenizer is not None:
- tokenizer_class = tables.tokenizer_classes.get(tokenizer)
- tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
- kwargs["tokenizer"] = tokenizer
+
- # build frontend if frontend is none None
- frontend = kwargs.get("frontend", None)
- if frontend is not None:
- frontend_class = tables.frontend_classes.get(frontend)
- frontend = frontend_class(**kwargs["frontend_conf"])
- kwargs["frontend"] = frontend
- kwargs["input_size"] = frontend.output_size()
-
-
- # build model
- model_class = tables.model_classes.get(kwargs["model"])
- vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
- vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-
-
+
# init_param
init_param = kwargs.get("init_param", None)
--
Gitblit v1.9.1
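
Note on the second hunk (not part of the diff itself): it swaps the manual tables-registry construction of tokenizer, frontend and model for a single funasr.AutoModel call, made with the device temporarily forced to "cpu" so the model is assembled before each DDP/FSDP rank claims its GPU; the configured device string is restored afterwards for the trainer. A minimal standalone sketch of that pattern follows. The wrapper function name is illustrative, and reading the built tokenizer/frontend back out of the config dict simply mirrors the hunk above; it is an assumption, not a documented AutoModel contract.

    from funasr import AutoModel

    def build_from_config(kwargs: dict):
        # Build on CPU first so construction does not touch the GPU,
        # then restore the device requested in the config.
        device = kwargs.get("device", "cpu")
        kwargs["device"] = "cpu"
        auto_model = AutoModel(**kwargs)   # assembles model, tokenizer, frontend from the config
        kwargs["device"] = device

        model = auto_model.model           # underlying torch.nn.Module used by the training loop
        tokenizer = kwargs["tokenizer"]    # read back from the config dict, as in the hunk above
        frontend = kwargs["frontend"]
        return model, tokenizer, frontend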