From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example

---
 funasr/bin/train.py |   13 +++++++------
 1 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8ea0c0d..d916509 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,3 +1,6 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+
 import os
 import sys
 import torch
@@ -76,9 +79,8 @@
         frontend = frontend_class(**kwargs["frontend_conf"])
         kwargs["frontend"] = frontend
         kwargs["input_size"] = frontend.output_size()
-    
-    # import pdb;
-    # pdb.set_trace()
+
+
     # build model
     model_class = tables.model_classes.get(kwargs["model"])
     model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
@@ -144,9 +146,8 @@
 
     # dataset
     dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
-    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
-    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
-                               **kwargs.get("dataset_conf"))
+    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
+    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
 
     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")

--
Gitblit v1.9.1