Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test #25

Open
wants to merge 1 commit into
base: integration
Choose a base branch
from
Open

Test #25

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added CIFAR10/output/confusion_matrix_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added CIFAR10/output/confusion_matrix_final.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
130 changes: 130 additions & 0 deletions CIFAR10/output/metrics_0.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
{
"epoch": 0,
"loss": 0.0213891339302063,
"accuracy": 0.5032,
"precision": 0.503870595423183,
"recall": 0.5046536204137805,
"f1": 0.499071014270288,
"confusion_matrix": [
[
252,
35,
30,
11,
15,
6,
10,
11,
120,
18
],
[
15,
327,
2,
6,
6,
1,
24,
6,
37,
86
],
[
57,
7,
171,
31,
114,
23,
57,
20,
14,
5
],
[
16,
8,
72,
142,
34,
101,
87,
36,
17,
15
],
[
29,
4,
78,
19,
217,
17,
72,
35,
10,
6
],
[
5,
6,
59,
69,
49,
209,
43,
49,
4,
6
],
[
4,
4,
21,
18,
92,
10,
325,
8,
9,
6
],
[
7,
10,
26,
21,
65,
37,
32,
251,
5,
28
],
[
71,
42,
7,
4,
2,
6,
6,
1,
324,
20
],
[
25,
92,
6,
2,
2,
1,
25,
9,
47,
298
]
]
}
129 changes: 129 additions & 0 deletions CIFAR10/output/metrics_final.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
{
"loss": 0.021847880232334138,
"accuracy": 0.4852,
"precision": 0.4945647953890349,
"recall": 0.4852,
"f1": 0.4678093676224666,
"confusion_matrix": [
[
211,
34,
108,
22,
17,
7,
41,
14,
488,
58
],
[
13,
629,
2,
4,
3,
3,
23,
8,
140,
175
],
[
28,
21,
234,
84,
264,
64,
141,
68,
72,
24
],
[
8,
19,
92,
327,
74,
115,
204,
68,
35,
58
],
[
21,
10,
96,
52,
457,
28,
172,
118,
35,
11
],
[
4,
11,
77,
214,
83,
320,
108,
132,
34,
17
],
[
3,
21,
32,
45,
107,
7,
723,
26,
14,
22
],
[
7,
9,
31,
77,
80,
62,
62,
592,
22,
58
],
[
19,
55,
29,
11,
3,
2,
20,
7,
794,
60
],
[
14,
204,
5,
13,
4,
2,
45,
20,
128,
565
]
]
}
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,6 @@ To include the dependency build files in VSCode, add the following to the includ
"${workspaceFolder}/bazel-PeerToPeer/external/{your_target_directory}"
]
```


protoc --python_out=ml proto/payload.proto
Empty file added ml/__init__.py
Empty file.
24 changes: 16 additions & 8 deletions ml/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from dataloader import CIFAR10Dataset, get_data_loaders
from utils import train, val, test

from proto import payload_pb2, utility_pb2


def nn_aggregator(state_dicts):
"""
Expand All @@ -35,8 +37,9 @@ def nn_aggregator(state_dicts):

def main():
# Set up the context and responder socket
port_send = int(input("Enter the ZMQ sender port number: "))
port_rec = port_send + 1
port_rec = int(input("Enter the ZMQ sender port number: "))
port_send = int(input("Enter the ZMQ reciever port number: "))
num_peers = int(input("Enter the number of peers: "))

context = zmq.Context()
responder = context.socket(zmq.REP)
Expand All @@ -48,23 +51,28 @@ def main():
sender.connect("tcp://localhost:" + str(port_send))

# recieve the models from fake_peer.py
print("Waiting for models...")
state_dicts = []
for i in range(3):
for i in range(num_peers):
sd = responder.recv()
sd = pickle.loads(sd)
state_dicts.append(sd)
responder.send_string("ACK")
agg_inp = payload_pb2.AggregatorInputData()
agg_inp.ParseFromString(sd)
agg_inp = pickle.loads(agg_inp.modelStateDict)
state_dicts.append(agg_inp)

# average the models
print("Averaging models...")
avg_state_dict = nn_aggregator(state_dicts)

# send the averaged model back to fake_peer.py
avg_model = SimpleCNN()
avg_model.load_state_dict(avg_state_dict)
avg_model = pickle.dumps(avg_model)

sender.send(avg_model)
_ = sender.recv_string()
print("Sending averaged model...")
tr = payload_pb2.AggregatorInputData()
tr.modelStateDict = avg_model
sender.send(tr.SerializeToString())

print("Model averaging complete.")
return
Expand Down
56 changes: 56 additions & 0 deletions ml/proto/payload_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading