| | |
| | | def model_summary(model: torch.nn.Module) -> str: |
| | | message = "Model structure:\n" |
| | | message += str(model) |
| | | # for p in model.parameters(): |
| | | # print(f"{p.numel()}") |
| | | tot_params = sum(p.numel() for p in model.parameters()) |
| | | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| | | |
| | | tot_params, num_params = 0, 0 |
| | | for name, param in model.named_parameters(): |
| | | print( |
| | | "name: {}, dtype: {}, device: {}, trainable: {}, shape: {}, numel: {}".format( |
| | | name, param.dtype, param.device, param.requires_grad, param.shape, param.numel() |
| | | ) |
| | | ) |
| | | tot_params += param.numel() |
| | | if param.requires_grad: |
| | | num_params += param.numel() |
| | | |
| | | percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params) |
| | | tot_params = get_human_readable_count(tot_params) |
| | | num_params = get_human_readable_count(num_params) |