From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/optimizers/fairseq_adam.py |   22 +++++++++-------------
 1 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/funasr/optimizers/fairseq_adam.py b/funasr/optimizers/fairseq_adam.py
index 9bdd0f8..32fae6a 100644
--- a/funasr/optimizers/fairseq_adam.py
+++ b/funasr/optimizers/fairseq_adam.py
@@ -38,13 +38,13 @@
     """
 
     def __init__(
-            self,
-            params,
-            lr=1e-3,
-            adam_betas=(0.9, 0.999),
-            adam_eps=1e-8,
-            weight_decay=0,
-            amsgrad=False,
+        self,
+        params,
+        lr=1e-3,
+        adam_betas=(0.9, 0.999),
+        adam_eps=1e-8,
+        weight_decay=0,
+        amsgrad=False,
     ):
         defaults = dict(
             lr=lr, betas=adam_betas, eps=adam_eps, weight_decay=weight_decay, amsgrad=amsgrad
@@ -104,9 +104,7 @@
                     state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
                     state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
                     if amsgrad:
-                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
-                            p_data_fp32
-                        )
+                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)
 
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 if amsgrad:
@@ -131,9 +129,7 @@
                 step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
 
                 if group["weight_decay"] != 0:
-                    p_data_fp32.add_(
-                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
-                    )
+                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
 
                 p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
 

--
Gitblit v1.9.1