From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: 星期五, 26 四月 2024 14:59:30 +0800
Subject: [PATCH] FunASR java http  client

---
 funasr/bin/train.py |   30 +++++++++++++++++-------------
 1 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 0ff4ba1..c02a66f 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -55,6 +55,8 @@
     torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
     torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
     torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+    # open tf32
+    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
     
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     if local_rank == 0:
@@ -88,7 +90,8 @@
     # freeze_param
     freeze_param = kwargs.get("freeze_param", None)
     if freeze_param is not None:
-        freeze_param = eval(freeze_param)
+        if "," in freeze_param:
+            freeze_param = eval(freeze_param)
         if isinstance(freeze_param, Sequence):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
@@ -128,7 +131,8 @@
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
 
-    logging.info(f"{model}")
+    if local_rank == 0:
+        logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
         
     # optim
@@ -149,8 +153,8 @@
     # dataset
     logging.info("Build dataloader")
     dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
-    # dataloader = dataloader_class(**kwargs)
-    dataloader_tr, dataloader_val = dataloader_class(**kwargs)
+    dataloader = dataloader_class(**kwargs)
+    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
     trainer = Trainer(local_rank=local_rank,
                       use_ddp=use_ddp,
                       use_fsdp=use_fsdp,
@@ -172,15 +176,12 @@
     except:
         writer = None
 
-    # if use_ddp or use_fsdp:
-    #     context = Join([model])
-    # else:
-    context = nullcontext()
 
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
-        with context:
-            # dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+        
+        for data_split_i in range(dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
             trainer.train_epoch(
                                 model=model,
                                 optim=optim,
@@ -189,15 +190,18 @@
                                 dataloader_train=dataloader_tr,
                                 dataloader_val=dataloader_val,
                                 epoch=epoch,
-                                writer=writer
+                                writer=writer,
+                                data_split_i=data_split_i,
+                                data_split_num=dataloader.data_split_num,
                                 )
-        scheduler.step()
+        
         trainer.validate_epoch(
             model=model,
             dataloader_val=dataloader_val,
             epoch=epoch,
             writer=writer
         )
+        scheduler.step()
 
         
         trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
@@ -212,7 +216,7 @@
 
 
     if trainer.rank == 0:
-        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
+        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
 
     trainer.close()
 

--
Gitblit v1.9.1