forked from JianhuanZeng/learning-to-collaborate
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: main.py
138 lines (114 loc) · 6.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# main.py
import copy
import pickle
import numpy as np
import pandas as pd
import torch
from utils.utils_data import get_data
from utils.utils_func import construct_log, get_random_dir_name, setup_seed
from hyper_model.train import Training_all
import os
import pdb
import argparse
import pickle
# Command-line configuration for federated Pareto-front training.
# NOTE: several defaults (num_users, local_bs, input_dim, output_dim,
# eps_prefer) are overwritten later in __main__ depending on --dataset.
parser = argparse.ArgumentParser()
# federated arguments
parser.add_argument('--num_users', type=int, default=5, help="number of users: K")
parser.add_argument('--shard_per_user', type=int, default=2, help="classes per user (each user has the num of classes)")
parser.add_argument('--target_usr', type=int, default=0, help="target user id")
# training arguments
parser.add_argument('--epochs_per_valid', type=int, default=50, help="run validation every this many rounds")
parser.add_argument('--total_hnet_epoch', type=int, default=5, help="inner update steps for the hypernet per round")
parser.add_argument('--total_ray_epoch', type=int, default=1, help="inner update steps for the preference ray per round")
parser.add_argument('--total_epoch', type=int, default=2000, help="total number of training rounds")
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument('--local_bs', type=int, default=512, help="local batch size: B")
parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
parser.add_argument('--lr_prefer', type=float, default=0.01, help="learning rate for preference vector")
parser.add_argument('--alpha', type=float, default=0.2, help="alpha for sampling the preference vector")
parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
parser.add_argument('--num_workers', type=int, default=0, help="the number of workers for the dataloader.")
parser.add_argument('--eps_prefer', type=float, default=0.1, help="lower-bound epsilon for the preference vector (recomputed as 1/(3*num_users) in main)")
parser.add_argument('--sigma', type=float, default=0.1, help="sigma hyperparameter (verify exact usage in the trainer)")
parser.add_argument('--std', type=float, default=0.1, help="std hyperparameter (verify exact usage in the trainer)")
parser.add_argument('--trainN', type=int, default=2000, help="the number of generated train samples .")
parser.add_argument('--testN', type=int, default=1000, help="the number of generated test samples.")
parser.add_argument('--solver_type', type=str, default="epo", help="the type of solving the model")
parser.add_argument('--sample_ray', action='store_true', help='whether sampling alpha for learning Pareto Front')
parser.add_argument('--train_baseline', action='store_true', help='whether train baseline for eicu dataset')
parser.add_argument('--baseline_type', type=str, default="fedave", help="the type of training baseline (fedave, local)")
# model structure
parser.add_argument('--n_classes', type=int, default=10, help="the number of classes.")
parser.add_argument('--entropy_weight', type=float, default=0.0, help="weight of the entropy regularization term")
parser.add_argument('--n_hidden', type=int, default=2, help="number of hidden layers for the hypernet.")
parser.add_argument('--embedding_dim', type=int, default=5, help="embedding dim for eicu embedding the categorical features")
parser.add_argument('--input_dim', type=int, default=20, help="input dim (generate dim) for the hypernet.")
parser.add_argument('--output_dim', type=int, default=2, help="output dimension of the generated target network")
parser.add_argument('--hidden_dim', type=int, default=100, help="hidden dim for the hypernet.")
parser.add_argument('--spec_norm', action='store_true', help='whether using spectral norm not')
# learning setup arguments
parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
parser.add_argument('--auto_deploy', action='store_true', help='whether auto deploy not')
# devices
parser.add_argument('--gpus', type=str, default="1", help='gpu index for training ("-1" for CPU)')
# dataset/log/outputs/ dir
parser.add_argument('--dataset', type=str, default='cifar10', help="name of dataset")
parser.add_argument('--data_root', type=str, default='data', help="root directory of the dataset")
parser.add_argument('--outputs_root', type=str, default='outputs', help="root directory for logs and outputs")
parser.add_argument('--target_dir', type=str, default='', help=" dir name of for saving all generating data")
args = parser.parse_args()
def run_and_save(model, log_dir):
    """Train *model*, persist its pickle_record, and mark the run as done.

    Writes <log_dir>/pickle.pkl (pickled ``model.pickle_record``) and creates
    an empty <log_dir>/done directory that downstream tooling can poll to
    detect a finished run.
    """
    model.train()
    with open(os.path.join(log_dir, "pickle.pkl"), "wb") as f:
        pickle.dump(model.pickle_record, f)
    os.makedirs(os.path.join(log_dir, "done"), exist_ok=True)


if __name__ == '__main__':
    # Resolve the output directory: a random name unless --target_dir is given.
    if args.target_dir == "":
        args.log_dir = os.path.join(args.outputs_root, get_random_dir_name())
    else:
        args.log_dir = os.path.join(args.outputs_root, args.target_dir)
    setup_seed(seed=args.seed)

    # Pick the training device. args.gpus is a *string* (argparse type=str),
    # so the CPU escape hatch must compare against "-1" — the original
    # compared against int -1, which was always True. Only the first
    # character of the string is used as the device index.
    initial_device = torch.device(
        'cuda:{}'.format(args.gpus[0])
        if torch.cuda.is_available() and args.gpus != "-1"
        else 'cpu')

    # Checkpoint directories under the run's log dir.
    args.hnet_model_dir = os.path.join(args.log_dir, "hnet_model_saved")
    args.local_hnet_model_dir = os.path.join(args.log_dir, "local_hnet_model_saved")
    args.local_tnet_model_dir = os.path.join(args.log_dir, "local_tnet_model_saved")
    # Epsilon bound for the preference vector, scaled by the number of users
    # (overrides the --eps_prefer CLI value).
    args.eps_prefer = 1.0 / (3 * args.num_users)
    logger = construct_log(args)

    # Per-dataset overrides of the generic defaults.
    # NOTE(review): local_bs == -1 presumably means full-batch training —
    # confirm against the dataloader implementation.
    if args.dataset == "adult":
        args.input_dim = 99
        args.output_dim = 2
        args.num_users = 2
        args.local_bs = -1
    elif args.dataset == "synthetic":
        args.output_dim = 1
        args.num_users = 6
        args.local_bs = -1
    elif args.dataset == "cifar10":
        args.local_bs = 512
        args.num_users = 10
    elif args.dataset == "danger_detection":
        args.num_users = 2
        args.local_bs = -1

    # A "local" baseline trains on the target user only; everything else
    # uses all users.
    if args.train_baseline and args.baseline_type == "local":
        users_used = [args.target_usr]
    else:
        users_used = list(range(args.num_users))

    dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
    model = Training_all(args, logger, dataset_train, dataset_test,
                         dict_users_train, dict_users_test, users_used=users_used)

    if args.auto_deploy:
        # Best-effort deployment mode: log failures instead of crashing.
        try:
            run_and_save(model, args.log_dir)
        except Exception as e:
            logger.info("error info: {}.".format(e))
    else:
        run_and_save(model, args.log_dir)