From 0a4a1d5257dace9561d95b38a9386539908dcd5e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 23 四月 2024 12:48:52 +0800
Subject: [PATCH] Dev gzf exp (#1645)

---
 funasr/bin/train.py |    5 +++--
 1 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 4ab2d8a..ab49c82 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -90,7 +90,8 @@
     # freeze_param
     freeze_param = kwargs.get("freeze_param", None)
     if freeze_param is not None:
-        freeze_param = eval(freeze_param)
+        if "," in freeze_param:
+            freeze_param = eval(freeze_param)
         if isinstance(freeze_param, Sequence):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
@@ -104,7 +105,7 @@
     if use_ddp:
         model = model.cuda(local_rank)
         model = DDP(model, device_ids=[local_rank],
-                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True))
+                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
     elif use_fsdp:
         # model = FSDP(model).cuda(local_rank)
 

--
Gitblit v1.9.1