1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
| #!/usr/bin/env python3
| import os
| from funasr.tasks.punctuation import PunctuationTask
|
|
def parse_args():
    """Parse the command-line options for punctuation training.

    Starts from the parser returned by ``PunctuationTask.get_parser()``
    (presumably an ``argparse.ArgumentParser`` -- confirm in funasr) and
    registers two script-specific options on top of it:

    * ``--gpu_id``: int, default ``0``.
    * ``--punc_list``: str, default ``None``.

    Returns:
        Whatever ``parse_args()`` of that parser returns; it is called
        without an explicit argument list.
    """
    cli = PunctuationTask.get_parser()

    # Options added by this script on top of the task's own parser.
    extra_options = (
        ("--gpu_id", {"type": int, "default": 0, "help": "local gpu id."}),
        ("--punc_list", {"type": str, "default": None, "help": "Punctuation list"}),
    )
    for flag, spec in extra_options:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
|
def main(args=None, cmd=None):
    """Run punctuation (punc) training.

    Thin entry point: both parameters are handed unchanged, by keyword,
    to ``PunctuationTask.main``. The result of that call is not returned.

    Args:
        args: Forwarded as ``args``; the script entry point passes the
            parsed command-line options here.
        cmd: Forwarded as ``cmd``.
    """
    forwarded = {"args": args, "cmd": cmd}
    PunctuationTask.main(**forwarded)
|
|
if __name__ == "__main__":
    args = parse_args()

    # Restrict CUDA to the single device chosen with --gpu_id.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}"

    # DDP settings: distributed mode is enabled only when more than one GPU
    # is requested (ngpu is presumably defined by the task's own parser).
    args.distributed = args.ngpu > 1

    main(args=args)
|
|