From 00ea1186f96e6732e2edb4fab6c0ed6896e3b352 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 19 十二月 2023 22:53:18 +0800
Subject: [PATCH] funasr2

---
 funasr/bin/train.py |    3 ++-
 1 files changed, 2 insertions(+), 1 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 72fa9fa..8112002 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -39,7 +39,7 @@
 	# preprocess_config(kwargs)
 	# import pdb; pdb.set_trace()
 	# set random seed
-	registry_tables.print_register_tables()
+	registry_tables.print()
 	set_all_random_seed(kwargs.get("seed", 0))
 	torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
 	torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@@ -72,6 +72,7 @@
 		frontend_class = registry_tables.frontend_classes.get(frontend.lower())
 		frontend = frontend_class(**kwargs["frontend_conf"])
 		kwargs["frontend"] = frontend
+		kwargs["input_size"] = frontend.output_size()
 	
 	# import pdb;
 	# pdb.set_trace()

--
Gitblit v1.9.1