From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example
---
funasr/bin/train.py | 15 ++++++++-------
1 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index d9d4d62..d916509 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,3 +1,6 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+
import os
import sys
import torch
@@ -76,9 +79,8 @@
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
-
- # import pdb;
- # pdb.set_trace()
+
+
# build model
model_class = tables.model_classes.get(kwargs["model"])
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
@@ -144,9 +146,8 @@
# dataset
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
- dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
- dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
- **kwargs.get("dataset_conf"))
+ dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
+ dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
# dataloader
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
@@ -154,7 +155,7 @@
if batch_sampler is not None:
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
- batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
+ batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
--
Gitblit v1.9.1