huangmingming
2023-01-30 adcee8828ef5d78b575043954deb662a35e318f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
 
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
 
from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args
 
 
def modelscope_finetune(params):
    """Fine-tune a ModelScope ASR model with the settings held on ``params``.

    The following attributes are read from ``params``: ``output_dir``,
    ``data_path``, ``model``, ``dataset_type``, ``batch_bins``,
    ``max_epoch`` and ``lr``.

    The output directory is created when it does not exist yet, the dataset
    is loaded through ``MsDataset.load``, and a trainer registered under
    ``Trainers.speech_asr_trainer`` is built and run. Nothing is returned.
    """
    output_dir_missing = not os.path.exists(params.output_dir)
    if output_dir_missing:
        os.makedirs(params.output_dir, exist_ok=True)

    # dataset split ["train", "validation"]
    dataset = MsDataset.load(params.data_path)

    # Everything the trainer needs is forwarded as its default arguments.
    trainer_args = {
        "model": params.model,
        "data_dir": dataset,
        "dataset_type": params.dataset_type,
        "work_dir": params.output_dir,
        "batch_bins": params.batch_bins,
        "max_epoch": params.max_epoch,
        "lr": params.lr,
    }
    asr_trainer = build_trainer(
        Trainers.speech_asr_trainer,
        default_args=trainer_args,
    )
    asr_trainer.train()
 
 
if __name__ == "__main__":
    # Start from the parameter object produced by modelscope_args for the
    # chosen pretrained model, then override the fine-tuning settings below.
    params = modelscope_args(
        model="damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch",
        data_path="./data",
    )

    # (attribute, value) pairs applied in order; edit these to change the run.
    overrides = (
        ("output_dir", "./checkpoint"),
        ("data_path", "./example_data/"),
        ("dataset_type", "small"),
        ("batch_bins", 16000),
        ("max_epoch", 50),
        ("lr", 0.00005),
    )
    for attr_name, attr_value in overrides:
        setattr(params, attr_name, attr_value)

    modelscope_finetune(params)