From 1e400d61a0036a17387a14dacdb6f898734fc3e2 Mon Sep 17 00:00:00 2001 From: josh-freeman Date: Tue, 9 Jul 2024 17:35:20 -0400 Subject: [PATCH 1/5] feat: add new tests for aggregation. Note: the variables expected and equals are useless boilerplate for a short function (a variable should only be declared if it is used at least twice, except in very few exceptions that need excplicit justification --- discojs/src/aggregator.spec.ts | 2 +- discojs/src/weights/aggregation.spec.ts | 53 ++++++++++++++++++++----- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/discojs/src/aggregator.spec.ts b/discojs/src/aggregator.spec.ts index 7a1e378e9..f1beaa707 100644 --- a/discojs/src/aggregator.spec.ts +++ b/discojs/src/aggregator.spec.ts @@ -20,7 +20,7 @@ AGGREGATORS.forEach(([name, Aggregator]) => describe(`${name} implements Aggregator contract`, () => { it("starts at round zero", () => { const aggregator = new Aggregator(); - + expect(aggregator.round).to.equal(0); }); diff --git a/discojs/src/weights/aggregation.spec.ts b/discojs/src/weights/aggregation.spec.ts index 6b507ad68..4751708a1 100644 --- a/discojs/src/weights/aggregation.spec.ts +++ b/discojs/src/weights/aggregation.spec.ts @@ -4,33 +4,68 @@ import { WeightsContainer, aggregation } from './index.js' describe('weights aggregation', () => { it('avg of weights with two operands', () => { - const actual = aggregation.avg([ + const expected = aggregation.avg([ WeightsContainer.of([1, 2, 3, -1], [-5, 6]), WeightsContainer.of([2, 3, 7, 1], [-10, 5]), WeightsContainer.of([3, 1, 5, 3], [-15, 19]) ]) - const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]) + const actual = WeightsContainer.of([2, 2, 5, 1], [-10, 10]) - assert.isTrue(actual.equals(expected)) + assert.isTrue(expected.equals(actual)) }) it('sum of weights with two operands', () => { - const actual = aggregation.sum([ + const expected = aggregation.sum([ [[3, -4], [9]], [[2, 13], [0]] ]) - const expected = WeightsContainer.of([5, 9], [9]) + const actual = WeightsContainer.of([5, 9], [9]) - assert.isTrue(actual.equals(expected)) + assert.isTrue(expected.equals(actual)) }) it('diff of weights with two operands', () => { - const actual = aggregation.diff([ + const expected = aggregation.diff([ [[3, -4, 5], [9, 1]], [[2, 13, 4], [0, 1]] ]) - const expected = WeightsContainer.of([1, -17, 1], [9, 0]) + const actual = WeightsContainer.of([1, -17, 1], [9, 0]) - assert.isTrue(actual.equals(expected)) + assert.isTrue(expected.equals(actual)) }) + + it('avg of weights with no operands throws an error', () => { + assert.throws(() => aggregation.avg([])) + }) + + it('sum of weights with no operands throws an error', () => { + assert.throws(() => aggregation.sum([])) + }) + + it('diff of weights with no operands throws an error', () => { + assert.throws(() => aggregation.diff([])) + }) + + it('avg of weights with different dimensions throws an error', () => { + assert.throws(() => aggregation.avg([ + [[3, -4], [9]], + [[2, 13, 4], [0, 1]] + ])) + }) + + it('sum of weights with different dimensions throws an error', () => { + assert.throws(() => aggregation.sum([ + [[3, -4], [9]], + [[2, 13, 4], [0, 1]] + ])) + }) + + it('diff of weights with different dimensions throws an error', () => { + assert.throws(() => aggregation.diff([ + [[3, -4], [9]], + [[2, 13, 4], [0, 1]] + ])) + }) + + }) From 4d2bd7f4b2cd1700a9414f0ec14e417f5f5dd8a5 Mon Sep 17 00:00:00 2001 From: josh-freeman Date: Thu, 11 Jul 2024 17:04:09 -0400 Subject: [PATCH 2/5] feat: tests of aggregator classes --- discojs/src/aggregator.spec.ts | 73 +++++++++++++++++++++++++++++++++- package-lock.json | 12 +++++- package.json | 3 +- 3 files changed, 84 insertions(+), 4 deletions(-) diff --git a/discojs/src/aggregator.spec.ts b/discojs/src/aggregator.spec.ts index f1beaa707..fcfb9863e 100644 --- a/discojs/src/aggregator.spec.ts +++ b/discojs/src/aggregator.spec.ts @@ -1,6 +1,6 @@ import { expect } from "chai"; import { Map, Range, Set } from "immutable"; - +import { mock, instance, when, anything } from 'ts-mockito'; import { Model, WeightsContainer } from "./index.js"; import { Aggregator, @@ -24,6 +24,49 @@ AGGREGATORS.forEach(([name, Aggregator]) => expect(aggregator.round).to.equal(0); }); + it("starts with no contributions", () => { + const aggregator = new Aggregator(); + + expect(aggregator.size).to.equal(0); + }) + + it("model correctly initialized", async () => { + const aggregator = new Aggregator(); + + const model = mock(Model); + + aggregator.setModel(instance(model)); + expect(aggregator.model).to.equal(instance(model)); + }); + + it("is full when created with no nodes", () => { + const aggregator = new Aggregator(); + + expect(aggregator.isFull()); + }); + + it("is not full when created with more than no nodes and empty", () => { + const aggregator = new Aggregator(); + aggregator.setNodes(Set.of("client 0")); + + expect(aggregator.isFull()).to.be.false; + }); + + + + it("is full when enough contributions", () => { + const aggregator = new Aggregator(); + aggregator.setNodes(Set.of("client 0", "client 1", "client 2")); + + for (let i = 0; i < 3; i++) + for (let r = 0; r < aggregator.communicationRounds; r++) + aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r); + + expect(aggregator.isFull()).to.be.true; + }); + + + it("moves forward with enough contributions", async () => { const aggregator = new Aggregator(); aggregator.setNodes(Set.of("client 0", "client 1", "client 2")); @@ -35,10 +78,36 @@ AGGREGATORS.forEach(([name, Aggregator]) => aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r); await results; // nothing to test - + expect(aggregator.round).to.equal(1); }); + it("does not move forward with not enough contributions", async () => { + const aggregator = new Aggregator(); + aggregator.setNodes(Set.of("client 0", "client 1", "client 2")); + + const results = aggregator.receiveResult(); + + for (let i = 0; i < 2; i++) + for (let r = 0; r < aggregator.communicationRounds; r++) + aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r); + + await results; + + expect(aggregator.round).to.equal(0); + }); + + it("Adding at the wrong round does not count", () => { + const aggregator = new Aggregator(); + aggregator.setNodes(Set.of("client 0", "client 1", "client 2")); + + for (let i = 0; i < 3; i++) + for (let r = 0; r < aggregator.communicationRounds; r++) + aggregator.add(`client ${i}`, WeightsContainer.of([i]), aggregator.round+1, r); + + expect(aggregator.size).to.equal(0); + }); + it("gives same results on each node", async () => { const network = setupNetwork(Aggregator); diff --git a/package-lock.json b/package-lock.json index f7646864d..12dfa4901 100644 --- a/package-lock.json +++ b/package-lock.json @@ -15,7 +15,8 @@ "webapp" ], "dependencies": { - "immutable": "4" + "immutable": "4", + "ts-mockito": "^2.6.1" }, "devDependencies": { "@typescript-eslint/eslint-plugin": "7", @@ -11054,6 +11055,15 @@ "integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==", "dev": true }, + "node_modules/ts-mockito": { + "version": "2.6.1", + "resolved": "https://registry.npmjs.org/ts-mockito/-/ts-mockito-2.6.1.tgz", + "integrity": "sha512-qU9m/oEBQrKq5hwfbJ7MgmVN5Gu6lFnIGWvpxSjrqq6YYEVv+RwVFWySbZMBgazsWqv6ctAyVBpo9TmAxnOEKw==", + "license": "MIT", + "dependencies": { + "lodash": "^4.17.5" + } + }, "node_modules/ts-node": { "version": "10.9.2", "resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.2.tgz", diff --git a/package.json b/package.json index c43d094d7..c2e2bf7f5 100644 --- a/package.json +++ b/package.json @@ -11,7 +11,8 @@ "webapp" ], "dependencies": { - "immutable": "4" + "immutable": "4", + "ts-mockito": "^2.6.1" }, "devDependencies": { "@typescript-eslint/eslint-plugin": "7", From 26ddd82e4341d923a4e7048a4a48f57b6732c9e6 Mon Sep 17 00:00:00 2001 From: josh-freeman Date: Mon, 15 Jul 2024 19:24:54 -0400 Subject: [PATCH 3/5] feat: decentralized spec ts test revived. Not sure if this test is perfect; there are some timeouts... --- discojs/src/aggregator.spec.ts | 4 +- server/tests/client/decentralized.spec.ts | 72 ++++++++++++----------- 2 files changed, 40 insertions(+), 36 deletions(-) diff --git a/discojs/src/aggregator.spec.ts b/discojs/src/aggregator.spec.ts index fcfb9863e..14997ab01 100644 --- a/discojs/src/aggregator.spec.ts +++ b/discojs/src/aggregator.spec.ts @@ -1,6 +1,6 @@ import { expect } from "chai"; import { Map, Range, Set } from "immutable"; -import { mock, instance, when, anything } from 'ts-mockito'; +import { mock, instance } from 'ts-mockito'; import { Model, WeightsContainer } from "./index.js"; import { Aggregator, @@ -30,7 +30,7 @@ AGGREGATORS.forEach(([name, Aggregator]) => expect(aggregator.size).to.equal(0); }) - it("model correctly initialized", async () => { + it("model correctly initialized", () => { const aggregator = new Aggregator(); const model = mock(Model); diff --git a/server/tests/client/decentralized.spec.ts b/server/tests/client/decentralized.spec.ts index 8818cdcd7..cf6c6e0f4 100644 --- a/server/tests/client/decentralized.spec.ts +++ b/server/tests/client/decentralized.spec.ts @@ -2,12 +2,14 @@ import type * as http from 'http' import type { Task } from '@epfml/discojs' import { aggregator as aggregators, client as clients, defaultTasks } from '@epfml/discojs' - +import { expect } from 'chai' import { startServer } from '../../src/index.js' +import { WeightsContainer } from '@epfml/discojs/dist/index.js' +import { List } from 'immutable' const TASK = defaultTasks.titanic.getTask() -function test ( +function test( name: string, Client: new (url: URL, task: Task, aggregator: aggregators.Aggregator) => clients.Client, Aggregator: new () => aggregators.Aggregator @@ -28,38 +30,40 @@ function test ( await client.disconnect() }) - // TODO @s314cy: update - // it('connect to other nodes', async () => { - // const users = List(await Promise.all([ - // getClient(Client, server, TASK, new Aggregator(TASK)), - // getClient(Client, server, TASK, new Aggregator(TASK)), - // getClient(Client, server, TASK, new Aggregator(TASK)) - // ])) - // try { - // await Promise.all(users.map(async (u) => await u.connect())) - - // const wss = List.of( - // WeightsContainer.of(tf.tensor(0)), - // WeightsContainer.of(tf.tensor(1)), - // WeightsContainer.of(tf.tensor(2)) - // ) - - // const tis = users.map(() => new informant.DecentralizedInformant(TASK, 0)) - - // // wait for others to connect - // await new Promise((resolve) => setTimeout(resolve, 1_000)) - - // await Promise.all( - // users.zip(wss, tis) - // .map(async ([u, ws, ti]) => await u.onRoundBeginCommunication(ws, 0, ti)) - // .toArray() - // ) - - // tis.forEach((ti) => expect(users.size).to.eq(ti.participants())) - // } finally { - // await Promise.all(users.map(async (u) => await u.disconnect())) - // } - // }) + it('connect to other nodes', async () => { + + const clients = List(await Promise.all([ + new Client(url, TASK, new Aggregator()), + new Client(url, TASK, new Aggregator()), + new Client(url, TASK, new Aggregator()) + ])) + + try { + await Promise.all(clients.map(async (c) => await c.connect())) + + // Ensure the weights are properly typed as WeightsContainer instances + const wss = List.of( + WeightsContainer.of([0]), + WeightsContainer.of([1]), + WeightsContainer.of([2]) + ) + + + await new Promise((resolve) => setTimeout(resolve, 1_000)) + + await Promise.all( + clients.zip(wss) + .map(async ([c, ws]) => await c.onRoundBeginCommunication(ws as WeightsContainer, 0)) + .toArray() + ) + + clients.forEach((client) => expect(clients.size).to.eq(client.nodes.size)) + } finally { + await Promise.all(clients.map(async (c) => await c.disconnect())) + } + + }) + }) } From 536bb53fbbc7b5565b45a862620bf0446774d6ad Mon Sep 17 00:00:00 2001 From: josh-freeman Date: Tue, 16 Jul 2024 13:26:01 -0400 Subject: [PATCH 4/5] fix: decentralized spec ts test without timeout --- discojs/src/client/federated/base.ts | 6 ++ server/tests/client/decentralized.spec.ts | 18 +++++- server/tests/router/decentralized.spec.ts | 70 ++++++++++++----------- 3 files changed, 58 insertions(+), 36 deletions(-) diff --git a/discojs/src/client/federated/base.ts b/discojs/src/client/federated/base.ts index f4581b316..c0e3179d0 100644 --- a/discojs/src/client/federated/base.ts +++ b/discojs/src/client/federated/base.ts @@ -76,6 +76,8 @@ export class Base extends Client { const received = await waitMessageWithTimeout( this.server, type.AssignNodeID, + undefined, + "Timeout waiting for own id" ); console.info(`[${received.id}] assign id generated by the server`); this._ownId = received.id; @@ -121,6 +123,8 @@ export class Base extends Client { const { payload, round } = await waitMessageWithTimeout( this.server, type.ReceiveServerPayload, + undefined, + "Timeout waiting for server payload" ); const serverRound = round; @@ -162,6 +166,8 @@ export class Base extends Client { const received = await waitMessageWithTimeout( this.server, type.ReceiveServerMetadata, + undefined, + "Timeout waiting for metadata map", ); if (received.metadataMap !== undefined) { this.metadataMap = Map( diff --git a/server/tests/client/decentralized.spec.ts b/server/tests/client/decentralized.spec.ts index cf6c6e0f4..578389b49 100644 --- a/server/tests/client/decentralized.spec.ts +++ b/server/tests/client/decentralized.spec.ts @@ -7,7 +7,7 @@ import { startServer } from '../../src/index.js' import { WeightsContainer } from '@epfml/discojs/dist/index.js' import { List } from 'immutable' -const TASK = defaultTasks.titanic.getTask() +const TASK = defaultTasks.mnist.getTask() function test( name: string, @@ -41,7 +41,6 @@ function test( try { await Promise.all(clients.map(async (c) => await c.connect())) - // Ensure the weights are properly typed as WeightsContainer instances const wss = List.of( WeightsContainer.of([0]), WeightsContainer.of([1]), @@ -57,7 +56,20 @@ function test( .toArray() ) - clients.forEach((client) => expect(clients.size).to.eq(client.nodes.size)) + // send contribution to peers that are waiting: is done via onRoundEndCommunication: + // + + + clients.forEach((client) => expect(clients.size, 'all nodes should be connected to each other').to.eq(client.nodes.size)) + + await Promise.all( + clients.zip(wss) + .map(async ([c, ws]) => await c.onRoundEndCommunication(ws as WeightsContainer, 0)) + .toArray() + ) + + clients.forEach((client) => expect(0, 'all nodes should be disconnected at the end of a round').to.eq(client.nodes.size)) + } finally { await Promise.all(clients.map(async (c) => await c.disconnect())) } diff --git a/server/tests/router/decentralized.spec.ts b/server/tests/router/decentralized.spec.ts index 0841c353a..6e492bddf 100644 --- a/server/tests/router/decentralized.spec.ts +++ b/server/tests/router/decentralized.spec.ts @@ -1,41 +1,44 @@ // import { agent as request } from 'supertest' +import { WeightsContainer } from "@epfml/discojs" + // import { serialization, WeightsContainer } from '@epfml/discojs-web' // import { getApp } from '../../src/get_server' -// const platformID = 'deai' -// const clients = { -// one: 'one', -// two: 'two' -// } -// const task = 'titanic' - -// const weights = WeightsContainer.of([1, 1], [1, 1]) - -// const newRound = 1 - -// function connectHeader ( -// platformID: string, -// taskID: string, -// clientID: string -// ): string { -// return `/${platformID}/connect/${taskID}/${clientID}` -// } - -// function disconnectHeader ( -// platformID: string, -// taskID: string, -// clientID: string -// ): string { -// return `/${platformID}/disconnect/${taskID}/${clientID}` -// } - -// describe(`${platformID} simple connection tests`, function () { -// this.timeout(30_000) - -// it('connect and then disconnect to valid task', async () => { +const platformID = 'deai' +const clients = { + one: 'one', + two: 'two' + } + const task = 'titanic' + + const weights = WeightsContainer.of([1, 1], [1, 1]) + + const newRound = 1 + +function connectHeader ( + platformID: string, + taskID: string, + clientID: string + ): string { + return `/${platformID}/connect/${taskID}/${clientID}` + } + +function disconnectHeader ( + platformID: string, + taskID: string, + clientID: string + ): string { + return `/${platformID}/disconnect/${taskID}/${clientID}` + } + +describe(`${platformID} simple connection tests`, function () { + this.timeout(30_000) + + it('connect and then disconnect to valid task', async () => { // const app = await getApp() + // await request(app) // .get(connectHeader(platformID, task, clients.one)) @@ -43,7 +46,7 @@ // await request(app) // .get(disconnectHeader(platformID, task, clients.one)) // .expect(200) -// }) + }) // it('connect to non existing task', async () => { // // the single test @@ -80,4 +83,5 @@ // }) // // TODO: Add a test with a whole round, etc -// }) +}) + From cef523b213fbefbd5f0b3e24de1d77b5dc11ca19 Mon Sep 17 00:00:00 2001 From: josh-freeman Date: Tue, 16 Jul 2024 19:32:22 -0400 Subject: [PATCH 5/5] feat: test decentralized spec ts: connect and disconnect from existing task, and check that connection fails for non existing task. Note: I need to figure out how to run the CI scripts before I commit. My code is unclean. --- server/src/router/decentralized/server.ts | 39 ++++++-- server/src/router/router.ts | 4 +- server/tests/router/decentralized.spec.ts | 108 +++++++++++++--------- 3 files changed, 97 insertions(+), 54 deletions(-) diff --git a/server/src/router/decentralized/server.ts b/server/src/router/decentralized/server.ts index dba00779e..d973984a7 100644 --- a/server/src/router/decentralized/server.ts +++ b/server/src/router/decentralized/server.ts @@ -31,14 +31,37 @@ export class Decentralized extends Server { public isValidUrl (url: string | undefined): boolean { const splittedUrl = url?.split('/') - return ( - splittedUrl !== undefined && - splittedUrl.length === 3 && - splittedUrl[0] === '' && - this.isValidTask(splittedUrl[1]) && - this.isValidWebSocket(splittedUrl[2]) - ) - } + // Assuming this code is inside a method of a class + console.log("Evaluating URL validation..."); + + if (splittedUrl === undefined) { + console.log("Failure: URL is undefined."); + return false; + } + + if (splittedUrl.length !== 3) { + console.log(`Failure: URL does not have exactly 3 segments. Found ${splittedUrl.length} segments: ${splittedUrl}`); + return false; + } + + if (splittedUrl[0] !== '') { + console.log("Failure: URL does not start with a '/'."); + return false; + } + + if (!this.isValidTask(splittedUrl[1])) { + console.log(`Failure: '${splittedUrl[1]}' is not a valid task.`); + return false; + } + + if (!this.isValidWebSocket(splittedUrl[2])) { + console.log(`Failure: '${splittedUrl[2]}' is not a valid WebSocket.`); + return false; + } + + console.log("URL is valid."); + return true; + } protected initTask (): void {} diff --git a/server/src/router/router.ts b/server/src/router/router.ts index 6a42a4165..16697ceb4 100644 --- a/server/src/router/router.ts +++ b/server/src/router/router.ts @@ -23,15 +23,17 @@ export class Router { const decentralized = new Decentralized(wsApplier, this.tasksAndModels) this.ownRouter = express.Router() - wsApplier.applyTo(this.ownRouter) + wsApplier.applyTo(this.ownRouter) process.nextTick(() => wsApplier.getWss().on('connection', (ws, req) => { if (!federated.isValidUrl(req.url) && !decentralized.isValidUrl(req.url)) { console.log('Connection refused') + ws.send('404') ws.terminate() ws.close() } + ws.send('200') }) ) diff --git a/server/tests/router/decentralized.spec.ts b/server/tests/router/decentralized.spec.ts index 6e492bddf..18635e7be 100644 --- a/server/tests/router/decentralized.spec.ts +++ b/server/tests/router/decentralized.spec.ts @@ -1,6 +1,11 @@ -// import { agent as request } from 'supertest' +// for get requests, we can use the library cypress -import { WeightsContainer } from "@epfml/discojs" +import { request, response } from "express" + +import { client, WeightsContainer } from "@epfml/discojs" +import { startServer } from "../../src/index.js" + +import { expect } from "chai" // import { serialization, WeightsContainer } from '@epfml/discojs-web' @@ -8,53 +13,67 @@ import { WeightsContainer } from "@epfml/discojs" const platformID = 'deai' const clients = { - one: 'one', - two: 'two' - } - const task = 'titanic' - - const weights = WeightsContainer.of([1, 1], [1, 1]) - - const newRound = 1 - -function connectHeader ( - platformID: string, - taskID: string, - clientID: string - ): string { - return `/${platformID}/connect/${taskID}/${clientID}` - } - -function disconnectHeader ( - platformID: string, - taskID: string, - clientID: string - ): string { - return `/${platformID}/disconnect/${taskID}/${clientID}` - } + one: 'one', + two: 'two' +} +const taskID = 'titanic' + +const weights = WeightsContainer.of([1, 1], [1, 1]) + +const newRound = 1 + +function connectHeader( + platformID: string, + taskID: string, + clientID: string +): string { + return `/${platformID}/connect` +} + +function disconnectHeader( + platformID: string, + taskID: string, + clientID: string +): string { + return `/${platformID}/disconnect`} + +import { Server } from "http"; + +let server: Server; +let url: URL + +beforeEach(async () => { + [server, url] = await startServer(); + +});afterEach(async () => {server.close()}) describe(`${platformID} simple connection tests`, function () { - this.timeout(30_000) + this.timeout(30_000) - it('connect and then disconnect to valid task', async () => { -// const app = await getApp() - + it('connect and then disconnect to valid task', async () => { + const ws = await new WebSocket(url+'titanic'); -// await request(app) -// .get(connectHeader(platformID, task, clients.one)) -// .expect(200) -// await request(app) -// .get(disconnectHeader(platformID, task, clients.one)) -// .expect(200) - }) -// it('connect to non existing task', async () => { -// // the single test -// const app = await getApp() -// await request(app) -// .get(connectHeader(platformID, 'fakeTask', clients.one)) -// .expect(404) -// }) + ws.onmessage = await function incoming(data) { + expect(data.data).to.equal('200'); + ws.close(); + + } + + }) + + +}) + + it('connect to non existing task', async () => { + const ws = await new WebSocket(url+'nonexistingtask'); + + ws.onmessage = await function incoming(data) { + expect(data.data).to.equal('404'); + ws.close(); + } + + }) // }) // describe(`${platformID} weight sharing tests`, function () { @@ -83,5 +102,4 @@ describe(`${platformID} simple connection tests`, function () { // }) // // TODO: Add a test with a whole round, etc -})