main.py
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.distributed as dist

from parameters import get_args
from pcode.master import Master
import pcode.worker as worker
import pcode.utils.topology as topology
import pcode.utils.checkpoint as checkpoint
import pcode.utils.logging as logging
import pcode.utils.param_parser as param_parser


def main(conf):
    # init the distributed world.
    try:
        dist.init_process_group("mpi")
    except AttributeError as e:
        print(f"failed to init the distributed world: {e}.")
        conf.distributed = False

    # init the config.
    init_config(conf)

    # start federated learning.
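    # rank 0 acts as the federated-learning server (Master); every other rank runs a worker process.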
    process = Master(conf) if conf.graph.rank == 0 else worker.get_worker_class(conf)
    process.run()


def init_config(conf):
    # define the graph for the computation.
    conf.graph = topology.define_graph_topology(
        world=conf.world,
        world_conf=conf.world_conf,
        n_participated=conf.n_participated,
        on_cuda=conf.on_cuda,
    )
    conf.graph.rank = dist.get_rank()

    # init related to randomness on cpu.
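    # unless all processes are meant to share a seed, derive a distinct seed per rank.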
    if not conf.same_seed_process:
        conf.manual_seed = 1000 * conf.manual_seed + conf.graph.rank
    conf.random_state = np.random.RandomState(conf.manual_seed)
    torch.manual_seed(conf.manual_seed)

    # configure cuda related.
    if conf.graph.on_cuda:
        assert torch.cuda.is_available()
        torch.cuda.manual_seed(conf.manual_seed)
        torch.cuda.set_device(conf.graph.device)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True if conf.train_fast else False

    # init the model arch info.
    conf.arch_info = (
        param_parser.dict_parser(conf.complex_arch)
        if conf.complex_arch is not None
        else {"master": conf.arch, "worker": conf.arch}
    )
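    # the worker entry is a ":"-separated string; split it into a list of architecture names.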
    conf.arch_info["worker"] = conf.arch_info["worker"].split(":")

    # parse the fl_aggregate scheme.
    conf._fl_aggregate = conf.fl_aggregate
    conf.fl_aggregate = (
        param_parser.dict_parser(conf.fl_aggregate)
        if conf.fl_aggregate is not None
        else {}
    )
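    # flatten the parsed entries onto conf as fl_aggregate_<key> attributes; the list
    # comprehension is used only for its setattr side effects (the comprehensions below
    # do the same for the personalization and partition schemes).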
    [setattr(conf, f"fl_aggregate_{k}", v) for k, v in conf.fl_aggregate.items()]

    # parse personalization scheme.
    conf._personalization_scheme = conf.personalization_scheme
    conf.personalization_scheme = (
        param_parser.dict_parser(conf.personalization_scheme)
        if conf.personalization_scheme is not None
        else {}
    )
    [
        setattr(conf, f"personalization_scheme_{k}", v)
        for k, v in conf.personalization_scheme.items()
    ]

    # parse training data partition scheme.
    conf._partition_data_conf = conf.partition_data_conf
    conf.partition_data_conf = (
        param_parser.dict_parser(conf.partition_data_conf)
        if conf.partition_data_conf is not None
        else {"distribution": "random"}
    )
    [setattr(conf, f"partition_{k}", v) for k, v in conf.partition_data_conf.items()]

    # define checkpoint for logging (for federated learning server).
    checkpoint.init_checkpoint(conf, rank=str(conf.graph.rank))

    # configure logger.
    conf.logger = logging.Logger(conf.checkpoint_dir)

    # display the arguments' info.
    if conf.graph.rank == 0:
        logging.display_args(conf)

    # sync the processes.
    dist.barrier()


if __name__ == "__main__":
    conf = get_args()
    main(conf)
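
# Usage sketch (an assumption about typical invocation, not stated in this file):
# the "mpi" backend requires starting the script through an MPI launcher, which
# spawns one process per rank, e.g.
#   mpirun -n <world_size> python main.py <arguments defined in parameters.get_args>
# Rank 0 then becomes the Master/server and the remaining ranks become workers.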