From 0efc87352ce7d3903dbdedbfa5d01ca5e1cb19e7 Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: Thu, 05 Dec 2024 15:15:38 +0800
Subject: [PATCH] Merge pull request #2267 from modelscope/dev_sx2
---
funasr/models/data2vec/ema_module.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/funasr/models/data2vec/ema_module.py b/funasr/models/data2vec/ema_module.py
index 4e46f50..3d7c21e 100644
--- a/funasr/models/data2vec/ema_module.py
+++ b/funasr/models/data2vec/ema_module.py
@@ -84,18 +84,14 @@
decay = self.decay
ema_state_dict = {}
- ema_params = (
- self.fp32_params if self.ema_fp32 else self.model.state_dict()
- )
+ ema_params = self.fp32_params if self.ema_fp32 else self.model.state_dict()
for key, param in new_model.state_dict().items():
if isinstance(param, dict):
continue
try:
ema_param = ema_params[key]
except KeyError:
- ema_param = (
- param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
- )
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
if param.shape != ema_param.shape:
raise ValueError(
@@ -107,7 +103,9 @@
# Do not decay a model.version pytorch param
continue
- if key in self.skip_keys or ("num_batches_tracked" in key and ema_param.dtype == torch.int64):
+ if key in self.skip_keys or (
+ "num_batches_tracked" in key and ema_param.dtype == torch.int64
+ ):
ema_param = param.to(dtype=ema_param.dtype).clone()
ema_params[key].copy_(ema_param)
else:
--
Gitblit v1.9.1