From 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8 Mon Sep 17 00:00:00 2001
From: wuhongsheng <664116298@qq.com>
Date: Fri, 05 Jul 2024 00:55:32 +0800
Subject: [PATCH] 优化speakid和语句匹配逻辑,部分解决speakid不从0递增问题 (#1870)
---
funasr/models/data2vec/ema_module.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/funasr/models/data2vec/ema_module.py b/funasr/models/data2vec/ema_module.py
index 4e46f50..3d7c21e 100644
--- a/funasr/models/data2vec/ema_module.py
+++ b/funasr/models/data2vec/ema_module.py
@@ -84,18 +84,14 @@
decay = self.decay
ema_state_dict = {}
- ema_params = (
- self.fp32_params if self.ema_fp32 else self.model.state_dict()
- )
+ ema_params = self.fp32_params if self.ema_fp32 else self.model.state_dict()
for key, param in new_model.state_dict().items():
if isinstance(param, dict):
continue
try:
ema_param = ema_params[key]
except KeyError:
- ema_param = (
- param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
- )
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
if param.shape != ema_param.shape:
raise ValueError(
@@ -107,7 +103,9 @@
# Do not decay a model.version pytorch param
continue
- if key in self.skip_keys or ("num_batches_tracked" in key and ema_param.dtype == torch.int64):
+ if key in self.skip_keys or (
+ "num_batches_tracked" in key and ema_param.dtype == torch.int64
+ ):
ema_param = param.to(dtype=ema_param.dtype).clone()
ema_params[key].copy_(ema_param)
else:
--
Gitblit v1.9.1