zhifu gao
2024-04-25 fc68b5ffe453235294a561737d8e84bb6c1689a4
funasr/train_utils/model_summary.py
@@ -47,6 +47,8 @@
def model_summary(model: torch.nn.Module) -> str:
    message = "Model structure:\n"
    message += str(model)
    # for p in model.parameters():
    #     print(f"{p.numel()}")
    tot_params = sum(p.numel() for p in model.parameters())
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)