From 1448e021accfdb03a381651cb5a8be6d1a6e8adf Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 14:59:26 +0800
Subject: [PATCH] aishell example

---
 funasr/bin/train.py |    8 +++++---
 1 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8ea0c0d..c9a4a67 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,3 +1,6 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+
 import os
 import sys
 import torch
@@ -144,9 +147,8 @@
 
     # dataset
     dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
-    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
-    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
-                               **kwargs.get("dataset_conf"))
+    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
+    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
 
     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")

--
Gitblit v1.9.1