From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: 星期五, 26 四月 2024 14:59:30 +0800
Subject: [PATCH] FunASR java http  client

---
 funasr/bin/train.py |   34 ++++++++++++++++++----------------
 1 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 880bb63..c02a66f 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -55,6 +55,8 @@
     torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
     torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
     torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+    # open tf32
+    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
     
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     if local_rank == 0:
@@ -88,7 +90,8 @@
     # freeze_param
     freeze_param = kwargs.get("freeze_param", None)
     if freeze_param is not None:
-        freeze_param = eval(freeze_param)
+        if "," in freeze_param:
+            freeze_param = eval(freeze_param)
         if isinstance(freeze_param, Sequence):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
@@ -173,15 +176,12 @@
     except:
         writer = None
 
-    # if use_ddp or use_fsdp:
-    #     context = Join([model])
-    # else:
-    #     context = nullcontext()
-    context = nullcontext()
+
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
-        with context:
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+        
+        for data_split_i in range(dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
             trainer.train_epoch(
                                 model=model,
                                 optim=optim,
@@ -190,15 +190,17 @@
                                 dataloader_train=dataloader_tr,
                                 dataloader_val=dataloader_val,
                                 epoch=epoch,
-                                writer=writer
+                                writer=writer,
+                                data_split_i=data_split_i,
+                                data_split_num=dataloader.data_split_num,
                                 )
-        with context:
-            trainer.validate_epoch(
-                model=model,
-                dataloader_val=dataloader_val,
-                epoch=epoch,
-                writer=writer
-            )
+        
+        trainer.validate_epoch(
+            model=model,
+            dataloader_val=dataloader_val,
+            epoch=epoch,
+            writer=writer
+        )
         scheduler.step()
 
         

--
Gitblit v1.9.1