kongdeqiang
5 days ago  28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/optimizers/fairseq_adam.py
@@ -38,13 +38,13 @@
    """
    def __init__(
            self,
            params,
            lr=1e-3,
            adam_betas=(0.9, 0.999),
            adam_eps=1e-8,
            weight_decay=0,
            amsgrad=False,
        self,
        params,
        lr=1e-3,
        adam_betas=(0.9, 0.999),
        adam_eps=1e-8,
        weight_decay=0,
        amsgrad=False,
    ):
        defaults = dict(
            lr=lr, betas=adam_betas, eps=adam_eps, weight_decay=weight_decay, amsgrad=amsgrad
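
For context on the reformatted signature: this is the standard torch.optim.Optimizer constructor interface, so the change is whitespace-only and callers are unaffected. A minimal usage sketch, assuming the class defined in this file is named FairseqAdam (the import path below is inferred from the file path and is an assumption, not something this diff shows):

import torch
from funasr.optimizers.fairseq_adam import FairseqAdam  # assumed name/path

model = torch.nn.Linear(10, 2)
optimizer = FairseqAdam(
    model.parameters(),
    lr=1e-3,
    adam_betas=(0.9, 0.999),  # fairseq-style name; mapped to "betas" in defaults
    adam_eps=1e-8,
    weight_decay=0.01,
    amsgrad=False,
)
loss = model(torch.randn(3, 10)).sum()
loss.backward()
optimizer.step()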
@@ -104,9 +104,7 @@
                    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
                    if amsgrad:
                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
                            p_data_fp32
                        )
                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
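
The lines above are the mixed-precision path: optimizer state is kept in fp32 even when the parameter itself is fp16, and .to(p_data_fp32) moves each moment tensor to the dtype and device of the fp32 master copy. A hedged sketch of that pattern (illustrative, not verbatim from this file):

import torch

p = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
p_data_fp32 = p.data.float()  # fp32 master copy used for the actual update
state = {
    "exp_avg": torch.zeros(4),     # first moment, kept in fp32
    "exp_avg_sq": torch.zeros(4),  # second moment, kept in fp32
}
# Tensor.to(other) adopts both the dtype and the device of `other`,
# which is what the collapsed one-liner in this hunk relies on.
state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
assert state["exp_avg"].dtype == torch.float32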
@@ -131,9 +129,7 @@
                  step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                  if group["weight_decay"] != 0:
-                    p_data_fp32.add_(
-                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
-                    )
+                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
                  p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
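
For reviewers checking the math around the collapsed line: step_size already folds in both bias corrections, the weight-decay line applies decoupled L2 decay (p <- p - lr * wd * p) directly to the fp32 copy, and addcdiv_ then takes the Adam step p <- p - step_size * exp_avg / denom. A self-contained sketch with made-up values (standard fairseq-Adam semantics, not code copied from this file):

import math
import torch

lr, wd, eps = 1e-3, 0.01, 1e-8
beta1, beta2, step = 0.9, 0.999, 10  # "step" stands in for state["step"]
p_data_fp32 = torch.randn(4)
exp_avg, exp_avg_sq = torch.randn(4), torch.rand(4)

bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
step_size = lr * math.sqrt(bias_correction2) / bias_correction1

denom = exp_avg_sq.sqrt().add_(eps)
p_data_fp32.add_(p_data_fp32, alpha=-wd * lr)            # decoupled weight decay
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)   # Adam update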