| | |
| | | state["exp_avg"] = state["exp_avg"].to(p_data_fp32) |
| | | state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) |
| | | if amsgrad: |
| | | state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( |
| | | p_data_fp32 |
| | | ) |
| | | state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32) |
| | | |
| | | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] |
| | | if amsgrad: |
| | |
| | | step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 |
| | | |
| | | if group["weight_decay"] != 0: |
| | | p_data_fp32.add_( |
| | | p_data_fp32, alpha=-group["weight_decay"] * group["lr"] |
| | | ) |
| | | p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) |
| | | |
| | | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) |
| | | |