kongdeqiang
6 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/train_utils/initialize.py
@@ -19,7 +19,6 @@
        init: Method of initialization.
    """
    # weight init
    for p in model.parameters():
        if p.dim() > 1:
@@ -40,9 +39,7 @@
    # reset some modules with default init
    for m in model.modules():
        if isinstance(
            m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)
        ):
        if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)):
            m.reset_parameters()
        if hasattr(m, "espnet_initialization_fn"):
            m.espnet_initialization_fn()
@@ -56,4 +53,3 @@
        model.frontend, "reload_pretrained_parameters", None
    ):
        model.frontend.reload_pretrained_parameters()