decay = self.decay

ema_state_dict = {}
ema_params = (
    self.fp32_params if self.ema_fp32 else self.model.state_dict()
)
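# Blend every entry of the incoming model's state dict into the EMA copy;
# non-tensor entries (nested dicts) are skipped below.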
for key, param in new_model.state_dict().items():
    if isinstance(param, dict):
        continue
    try:
        ema_param = ema_params[key]
    except KeyError:
        # First time we see this key: seed the EMA entry from the live param.
        ema_param = (
            param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
        )

    if param.shape != ema_param.shape:
        raise ValueError(
            "incompatible tensor shapes between model param and ema param"
            + " {} vs. {}".format(param.shape, ema_param.shape)
        )

    if "version" in key:
        # Do not decay a model.version pytorch param
        continue
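
    # Integer buffers (e.g. BatchNorm's num_batches_tracked) and explicitly
    # skipped keys are copied into the EMA verbatim instead of being decayed.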
    if key in self.skip_keys or (
        "num_batches_tracked" in key and ema_param.dtype == torch.int64
    ):
        ema_param = param.to(dtype=ema_param.dtype).clone()
        ema_params[key].copy_(ema_param)
    else:
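        # Decayed update path: blend the live param into the running average
        # using `decay` (ema <- decay * ema + (1 - decay) * param).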