From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat EMA parameter update in data2vec ema_module.py
---
funasr/models/data2vec/ema_module.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)
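
Notes: the hunks below only rewrap conditional assignments; the underlying
logic is the standard exponential-moving-average (EMA) parameter update,
p_ema = decay * p_ema + (1 - decay) * p. For reference, here is a minimal,
self-contained sketch of that update. It mirrors the patched code but is
illustrative only: the function name `ema_step`, its signature, and the
`skip_keys` default are assumptions, not the FunASR API.

    import copy
    import torch

    def ema_step(ema_params, new_params, decay, skip_keys=()):
        """Blend new_params into ema_params in place (illustrative sketch)."""
        for key, param in new_params.items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                # First time we see this key: seed the EMA copy from the model.
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )
                ema_params[key] = ema_param
            if key in skip_keys or (
                "num_batches_tracked" in key and ema_param.dtype == torch.int64
            ):
                # Non-decayable buffers (e.g. BatchNorm counters) are copied verbatim;
                # decaying an int64 counter would be meaningless and would error.
                ema_params[key].copy_(param.to(dtype=ema_param.dtype))
            else:
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
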
diff --git a/funasr/models/data2vec/ema_module.py b/funasr/models/data2vec/ema_module.py
index 4e46f50..3d7c21e 100644
--- a/funasr/models/data2vec/ema_module.py
+++ b/funasr/models/data2vec/ema_module.py
@@ -84,18 +84,14 @@
decay = self.decay
ema_state_dict = {}
- ema_params = (
- self.fp32_params if self.ema_fp32 else self.model.state_dict()
- )
+ ema_params = self.fp32_params if self.ema_fp32 else self.model.state_dict()
for key, param in new_model.state_dict().items():
if isinstance(param, dict):
continue
try:
ema_param = ema_params[key]
except KeyError:
- ema_param = (
- param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
- )
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
if param.shape != ema_param.shape:
raise ValueError(
@@ -107,7 +103,9 @@
# Do not decay a model.version pytorch param
continue
- if key in self.skip_keys or ("num_batches_tracked" in key and ema_param.dtype == torch.int64):
+ if key in self.skip_keys or (
+ "num_batches_tracked" in key and ema_param.dtype == torch.int64
+ ):
ema_param = param.to(dtype=ema_param.dtype).clone()
ema_params[key].copy_(ema_param)
else:
--
Gitblit v1.9.1