From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/data2vec/ema_module.py |   12 +++++-------
 1 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/funasr/models/data2vec/ema_module.py b/funasr/models/data2vec/ema_module.py
index 4e46f50..3d7c21e 100644
--- a/funasr/models/data2vec/ema_module.py
+++ b/funasr/models/data2vec/ema_module.py
@@ -84,18 +84,14 @@
         decay = self.decay
 
         ema_state_dict = {}
-        ema_params = (
-            self.fp32_params if self.ema_fp32 else self.model.state_dict()
-        )
+        ema_params = self.fp32_params if self.ema_fp32 else self.model.state_dict()
         for key, param in new_model.state_dict().items():
             if isinstance(param, dict):
                 continue
             try:
                 ema_param = ema_params[key]
             except KeyError:
-                ema_param = (
-                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                )
+                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
 
             if param.shape != ema_param.shape:
                 raise ValueError(
@@ -107,7 +103,9 @@
                 # Do not decay a model.version pytorch param
                 continue
 
-            if key in self.skip_keys or ("num_batches_tracked" in key and ema_param.dtype == torch.int64):
+            if key in self.skip_keys or (
+                "num_batches_tracked" in key and ema_param.dtype == torch.int64
+            ):
                 ema_param = param.to(dtype=ema_param.dtype).clone()
                 ema_params[key].copy_(ema_param)
             else:

--
Gitblit v1.9.1