| | |
| | | def main_worker(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | |
| | | args.ngpu = 0 |
| | | # 0. Init distributed process |
| | | distributed_option = build_dataclass(DistributedOption, args) |
| | | # Setting distributed_option.dist_rank, etc. |
| | |
| | | raise RuntimeError( |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | #model = model.to( |
| | | # dtype=getattr(torch, args.train_dtype), |
| | | # device="cuda" if args.ngpu > 0 else "cpu", |
| | | #) |
| | | model = model.to( |
| | | dtype=getattr(torch, args.train_dtype), |
| | | device="cpu", |
| | | device="cuda" if args.ngpu > 0 else "cpu", |
| | | ) |
| | | for t in args.freeze_param: |
| | | for k, p in model.named_parameters(): |