From 557b913c5d78945e49cb2ac2bf254a2de40b6cd5 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 14 三月 2024 11:48:31 +0800
Subject: [PATCH] v1.0.16

---
 funasr/bin/train_llm.py |   35 +++++++++++++----------------------
 1 files changed, 13 insertions(+), 22 deletions(-)

diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
index 3c93371..a33cd53 100644
--- a/funasr/bin/train_llm.py
+++ b/funasr/bin/train_llm.py
@@ -26,7 +26,7 @@
 # from funasr.tokenizer.build_tokenizer import build_tokenizer
 # from funasr.tokenizer.token_id_converter import TokenIDConverter
 # from funasr.tokenizer.funtoken import build_tokenizer
-
+from funasr import AutoModel
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
@@ -60,6 +60,16 @@
     if use_ddp or use_fsdp:
         dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
         torch.cuda.set_device(local_rank)
+        
+    device = kwargs.get("device", "cpu")
+    kwargs["device"] = "cpu"
+    model = AutoModel(**kwargs)
+    kwargs["device"] = device
+    model = model.model
+    tokenizer = kwargs["tokenizer"]
+    frontend = kwargs["frontend"]
+    
+    
     
     # save config.yaml
     if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
@@ -68,28 +78,9 @@
         OmegaConf.save(config=kwargs, f=yaml_file)
         logging.info("config.yaml is saved to: %s", yaml_file)
 
-    tokenizer = kwargs.get("tokenizer", None)
-    if tokenizer is not None:
-        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
-        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
-        kwargs["tokenizer"] = tokenizer
+
     
-    # build frontend if frontend is none None
-    frontend = kwargs.get("frontend", None)
-    if frontend is not None:
-        frontend_class = tables.frontend_classes.get(frontend)
-        frontend = frontend_class(**kwargs["frontend_conf"])
-        kwargs["frontend"] = frontend
-        kwargs["input_size"] = frontend.output_size()
-
-
-    # build model
-    model_class = tables.model_classes.get(kwargs["model"])
-    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
-    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
-    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-
-
+    
 
     # init_param
     init_param = kwargs.get("init_param", None)

--
Gitblit v1.9.1