zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/optimizers/fairseq_adam.py
@@ -104,9 +104,7 @@
                    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
                    if amsgrad:
                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
                            p_data_fp32
                        )
                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
@@ -131,9 +129,7 @@
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
                    )
                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)