kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/data2vec/ema_module.py
@@ -84,18 +84,14 @@
        decay = self.decay
        ema_state_dict = {}
        ema_params = (
            self.fp32_params if self.ema_fp32 else self.model.state_dict()
        )
        ema_params = self.fp32_params if self.ema_fp32 else self.model.state_dict()
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )
                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
            if param.shape != ema_param.shape:
                raise ValueError(
@@ -107,7 +103,9 @@
                # Do not decay a model.version pytorch param
                continue
            if key in self.skip_keys or ("num_batches_tracked" in key and ema_param.dtype == torch.int64):
            if key in self.skip_keys or (
                "num_batches_tracked" in key and ema_param.dtype == torch.int64
            ):
                ema_param = param.to(dtype=ema_param.dtype).clone()
                ema_params[key].copy_(ema_param)
            else: