游雁
2023-11-17 244c033fbaeae15faf8b0351365bdb7607b2e2bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
 
 
class SGD(torch.optim.SGD):
    """Thin inheritance of torch.optim.SGD to bind the required arguments, 'lr'
 
    Note that
    the arguments of the optimizer invoked by AbsTask.main()
    must have default value except for 'param'.
 
    I can't understand why only SGD.lr doesn't have the default value.
    """
 
    def __init__(
        self,
        params,
        lr: float = 0.1,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
    ):
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )