From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/optimizers/fairseq_adam.py | 22 +++++++++-------------
1 files changed, 9 insertions(+), 13 deletions(-)
diff --git a/funasr/optimizers/fairseq_adam.py b/funasr/optimizers/fairseq_adam.py
index 9bdd0f8..32fae6a 100644
--- a/funasr/optimizers/fairseq_adam.py
+++ b/funasr/optimizers/fairseq_adam.py
@@ -38,13 +38,13 @@
"""
def __init__(
- self,
- params,
- lr=1e-3,
- adam_betas=(0.9, 0.999),
- adam_eps=1e-8,
- weight_decay=0,
- amsgrad=False,
+ self,
+ params,
+ lr=1e-3,
+ adam_betas=(0.9, 0.999),
+ adam_eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
):
defaults = dict(
lr=lr, betas=adam_betas, eps=adam_eps, weight_decay=weight_decay, amsgrad=amsgrad
@@ -104,9 +104,7 @@
state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
if amsgrad:
- state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
- p_data_fp32
- )
+ state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsgrad:
@@ -131,9 +129,7 @@
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
if group["weight_decay"] != 0:
- p_data_fp32.add_(
- p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
- )
+ p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
--
Gitblit v1.9.1