From 1448e021accfdb03a381651cb5a8be6d1a6e8adf Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 14:59:26 +0800
Subject: [PATCH] aishell example
---
funasr/bin/train.py | 8 +++++---
1 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8ea0c0d..c9a4a67 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,3 +1,6 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+
import os
import sys
import torch
@@ -144,9 +147,8 @@
# dataset
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
- dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
- dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
- **kwargs.get("dataset_conf"))
+ dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
+ dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
# dataloader
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
--
Gitblit v1.9.1