import os
|
|
from modelscope.metainfo import Trainers
|
from modelscope.trainers import build_trainer
|
|
from funasr.datasets.ms_dataset import MsDataset
|
from funasr.utils.modelscope_param import modelscope_args
|
|
|
def modelscope_finetune(params):
|
if not os.path.exists(params.output_dir):
|
os.makedirs(params.output_dir, exist_ok=True)
|
# dataset split ["train", "validation"]
|
ds_dict = MsDataset.load(params.data_path)
|
kwargs = dict(
|
model=params.model,
|
data_dir=ds_dict,
|
dataset_type=params.dataset_type,
|
work_dir=params.output_dir,
|
batch_bins=params.batch_bins,
|
max_epoch=params.max_epoch,
|
lr=params.lr)
|
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
|
trainer.train()
|
|
|
if __name__ == '__main__':
|
params = modelscope_args(model="damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch",
|
data_path="./data")
|
params.output_dir = "./checkpoint"
|
params.data_path = "./example_data/"
|
params.dataset_type = "small"
|
params.batch_bins = 16000
|
params.max_epoch = 50
|
params.lr = 0.00005
|
|
modelscope_finetune(params)
|