speech_asr
2023-04-19 680cdb55bbde415c2f750e58808faedc6d1a6bf3
update
2 files modified
2 files added
81 lines changed

Files changed:
    funasr/bin/train.py              22 ●●●●
    funasr/utils/build_model.py       2 ●●●
    funasr/utils/build_optimizer.py  26 ●●●●●
    funasr/utils/build_scheduler.py  31 ●●●●●
funasr/bin/train.py
@@ -9,6 +9,8 @@
 from funasr.utils.build_dataloader import build_dataloader
 from funasr.utils.build_distributed import build_distributed
 from funasr.utils.prepare_data import prepare_data
+from funasr.utils.build_optimizer import build_optimizer
+from funasr.utils.build_scheduler import build_scheduler
 from funasr.utils.types import str2bool
@@ -355,20 +357,6 @@
                                                                     distributed_option.dist_rank,
                                                                     distributed_option.local_rank))
-    # optimizers = cls.build_optimizers(args, model=model)
-    # schedulers = []
-    # for i, optim in enumerate(optimizers, 1):
-    #     suf = "" if i == 1 else str(i)
-    #     name = getattr(args, f"scheduler{suf}")
-    #     conf = getattr(args, f"scheduler{suf}_conf")
-    #     if name is not None:
-    #         cls_ = scheduler_classes.get(name)
-    #         if cls_ is None:
-    #             raise ValueError(
-    #                 f"must be one of {list(scheduler_classes)}: {name}"
-    #             )
-    #         scheduler = cls_(optim, **conf)
-    #     else:
-    #         scheduler = None
-    #
-    #     schedulers.append(scheduler)
+    model = build_model(args)
+    optimizers = build_optimizer(args, model=model)
+    scheduler = build_scheduler(args, optimizers)
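
The three helper calls above replace the commented-out scheduler loop. A minimal sketch of how the refactored sequence is wired together, assuming an argparse.Namespace that carries the fields the helpers read (optim/optim_conf and scheduler/scheduler_conf); the values below are placeholders for illustration, not taken from the commit:

    import argparse

    from funasr.utils.build_model import build_model
    from funasr.utils.build_optimizer import build_optimizer
    from funasr.utils.build_scheduler import build_scheduler

    args = argparse.Namespace(
        token_list=None,  # skip vocabulary loading in this sketch
        optim="adam", optim_conf={"lr": 1e-3},
        scheduler="steplr", scheduler_conf={"step_size": 10},
    )

    model = build_model(args)  # the real call also needs the model fields on args
    optimizers = build_optimizer(args, model=model)
    scheduler = build_scheduler(args, optimizers)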
funasr/utils/build_model.py
@@ -2,7 +2,7 @@
 def build_model(args):
     if args.token_list is not None:
-        with open(args.token_list, encoding="utf-8") as f:
+        with open(args.token_list) as f:
             token_list = [line.rstrip() for line in f]
             args.token_list = list(token_list)
             vocab_size = len(token_list)
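
For context, args.token_list names a plain-text vocabulary file with one token per line; the loop strips trailing newlines and the list length becomes the vocabulary size. A minimal sketch with a hypothetical three-token file:

    # Hypothetical tokens.txt: one token per line.
    with open("tokens.txt", "w", encoding="utf-8") as f:
        f.write("<blank>\n<s>\n</s>\n")

    with open("tokens.txt", encoding="utf-8") as f:
        token_list = [line.rstrip() for line in f]

    print(token_list)       # ['<blank>', '<s>', '</s>']
    print(len(token_list))  # vocab_size == 3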
funasr/utils/build_optimizer.py
New file
@@ -0,0 +1,26 @@
import torch

from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD


def build_optimizer(args, model):
    optim_classes = dict(
        adam=torch.optim.Adam,
        fairseq_adam=FairseqAdam,
        adamw=torch.optim.AdamW,
        sgd=SGD,
        adadelta=torch.optim.Adadelta,
        adagrad=torch.optim.Adagrad,
        adamax=torch.optim.Adamax,
        asgd=torch.optim.ASGD,
        lbfgs=torch.optim.LBFGS,
        rmsprop=torch.optim.RMSprop,
        rprop=torch.optim.Rprop,
    )

    optim_class = optim_classes.get(args.optim)
    if optim_class is None:
        raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
    optimizer = optim_class(model.parameters(), **args.optim_conf)
    return optimizer
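
A usage sketch for the new helper, with a toy linear model standing in for the ASR model; args only needs the two fields the function reads:

    import argparse

    import torch

    model = torch.nn.Linear(4, 2)  # toy stand-in for the real model
    args = argparse.Namespace(optim="adam", optim_conf={"lr": 1e-3})

    optimizer = build_optimizer(args, model)
    print(type(optimizer).__name__)  # Adam

An unrecognized name falls through optim_classes.get and raises the ValueError above, listing the valid keys.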
funasr/utils/build_scheduler.py
New file
@@ -0,0 +1,31 @@
import torch
import torch.multiprocessing
import torch.nn
import torch.optim

from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR


def build_scheduler(args, optimizer):
    scheduler_classes = dict(
        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lambdalr=torch.optim.lr_scheduler.LambdaLR,
        steplr=torch.optim.lr_scheduler.StepLR,
        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
        noamlr=NoamLR,
        warmuplr=WarmupLR,
        tri_stage=TriStageLR,
        cycliclr=torch.optim.lr_scheduler.CyclicLR,
        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    )

    scheduler_class = scheduler_classes.get(args.scheduler)
    if scheduler_class is None:
        raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
    scheduler = scheduler_class(optimizer, **args.scheduler_conf)
    return scheduler
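
And a matching sketch for build_scheduler, using the torch built-in steplr entry so the example does not depend on the FunASR scheduler classes:

    import argparse

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    args = argparse.Namespace(
        scheduler="steplr",
        scheduler_conf={"step_size": 10, "gamma": 0.5},
    )
    scheduler = build_scheduler(args, optimizer)

    for _ in range(3):       # one scheduler step per optimizer step
        optimizer.step()
        scheduler.step()     # StepLR decays lr by 0.5 every 10 steps

Note that the registered keys mix cases ("CosineAnnealingLR" vs "steplr"), so the configured name must match the dict key exactly.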