| | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | params, |
| | | lr=1e-3, |
| | | adam_betas=(0.9, 0.999), |
| | | adam_eps=1e-8, |
| | | weight_decay=0, |
| | | amsgrad=False, |
| | | self, |
| | | params, |
| | | lr=1e-3, |
| | | adam_betas=(0.9, 0.999), |
| | | adam_eps=1e-8, |
| | | weight_decay=0, |
| | | amsgrad=False, |
| | | ): |
| | | defaults = dict( |
| | | lr=lr, betas=adam_betas, eps=adam_eps, weight_decay=weight_decay, amsgrad=amsgrad |
| | |
| | | state["exp_avg"] = state["exp_avg"].to(p_data_fp32) |
| | | state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) |
| | | if amsgrad: |
| | | state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( |
| | | p_data_fp32 |
| | | ) |
| | | state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32) |
| | | |
| | | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] |
| | | if amsgrad: |
| | |
| | | step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 |
| | | |
| | | if group["weight_decay"] != 0: |
| | | p_data_fp32.add_( |
| | | p_data_fp32, alpha=-group["weight_decay"] * group["lr"] |
| | | ) |
| | | p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) |
| | | |
| | | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) |
| | | |