diff --git a/cli/src/benchmark_gpt.ts b/cli/src/benchmark_gpt.ts index a80d07e79..f4faab276 100644 --- a/cli/src/benchmark_gpt.ts +++ b/cli/src/benchmark_gpt.ts @@ -55,8 +55,7 @@ async function main(args: Required): Promise { contextLength, batchSize, modelPath } = args // Launch a server instance - const disco = await Server.of(defaultTasks.wikitext); - const [server, url] = await disco.serve(); + const [server, url] = await new Server().serve(undefined, defaultTasks.wikitext); // Fetch the wikitext task from the server const tasks = await fetchTasks(url) diff --git a/cli/src/cli.ts b/cli/src/cli.ts index cb470b275..4591a80a6 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -39,8 +39,7 @@ async function main (provider: TaskProvider, numberOfUsers: number): Promise const results = new Promise((resolve) => aggregator.on("aggregation", resolve), ); - + + let promises = List>() for (let i = 0; i < 3; i++) - for (let r = 0; r < aggregator.communicationRounds; r++) - aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r); - + for (let r = 0; r < aggregator.communicationRounds; r++){ + promises = promises.push(aggregator.getPromiseForAggregation()) + aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r) + } + await Promise.all(promises) await results; // nothing to test expect(aggregator.round).to.equal(1); @@ -57,7 +60,7 @@ AGGREGATORS.forEach(([name, Aggregator]) => id, [agg, WeightsContainer.of([ws])], ]), - ), + ), 0 ) ) .valueSeq() @@ -94,6 +97,7 @@ export function setupNetwork( // run all rounds of communication export async function communicate( networkWithContributions: Map, + aggregationRound: number ): Promise> { const communicationsRound = networkWithContributions.first()?.[0].communicationRounds; @@ -125,7 +129,7 @@ export async function communicate( agg .makePayloads(contrib) .entrySeq() - .forEach(([to, payload]) => network.get(to)?.add(id, payload, 0, r)), + .forEach(([to, payload]) => network.get(to)?.add(id, payload, aggregationRound, 
r)), ); contributions = Map(await Promise.all(nextContributions)); diff --git a/discojs/src/aggregator/base.ts b/discojs/src/aggregator/aggregator.ts similarity index 69% rename from discojs/src/aggregator/base.ts rename to discojs/src/aggregator/aggregator.ts index 4a036f137..41611a0f4 100644 --- a/discojs/src/aggregator/base.ts +++ b/discojs/src/aggregator/aggregator.ts @@ -1,7 +1,7 @@ import createDebug from "debug"; import { Map, Set } from 'immutable' -import type { client } from '../index.js' +import type { client, WeightsContainer } from '../index.js' import { EventEmitter } from '../utils/event_emitter.js' @@ -17,10 +17,10 @@ export enum AggregationStep { * Main, abstract, aggregator class whose role is to buffer contributions and to produce * a result based off their aggregation, whenever some defined condition is met. * - * Emits an event whenever an aggregation step is performed. - * Users wait for this event to fetch the aggregation result. + * Emits an event whenever an aggregation step is performed with the counrd's aggregated weights. + * Users subscribes to this event to get the aggregation result. */ -export abstract class Base extends EventEmitter<{'aggregation': T }> { +export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsContainer }> { /** * Contains the ids of all active nodes, i.e. members of the aggregation group at * a given round. It is a subset of all the nodes available in the network. @@ -31,8 +31,8 @@ export abstract class Base extends EventEmitter<{'aggregation': T }> { * It defines the effective aggregation group, which is possibly a subset * of all active nodes, depending on the aggregation scheme. 
*/ - // communication round -> NodeID -> T - protected contributions: Map> + // communication round -> NodeID -> WeightsContainer + protected contributions: Map> /** * The current aggregation round, used for assessing whether a node contribution is recent enough @@ -61,29 +61,63 @@ export abstract class Base extends EventEmitter<{'aggregation': T }> { this.contributions = Map() this._nodes = Set() + } - // On every aggregation, update the object's state to match the current aggregation - // and communication rounds. - this.on('aggregation', () => this.nextRound()) + /** + * Convenience method to subscribe to the 'aggregation' event. + * Await this promise returns the aggregated weights for the current round. + * + * @returns a promise for the aggregated weights + */ + getPromiseForAggregation(): Promise { + return new Promise((resolve) => this.once('aggregation', resolve)); } /** * Adds a node's contribution to the aggregator for the given aggregation and communication rounds. * The aggregation round is increased whenever a new global model is obtained and local models are updated. * Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation - * which requires multiple steps to obtain a global model) - * The contribution will be aggregated during the next aggregation step. + * which requires multiple steps to obtain a global model) + * The contribution is aggregated during the next aggregation step. 
+ * * @param nodeId The node's id * @param contribution The node's contribution - * @param round aggregation round of the contribution was made - * @param communicationRound communication round the contribution was made within the aggregation round - * @returns boolean, true if the contribution has been successfully taken into account or False if it has been rejected */ - abstract add (nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean + add(nodeId: client.NodeID, contribution: WeightsContainer, + aggregationRound: number, communicationRound?: number): void { + if (!this.isValidContribution(nodeId, aggregationRound)) + throw new Error("Tried adding an invalid contribution. Handle this case before calling add.") + + // call the abstract method _add, implemented by subclasses + this._add(nodeId, contribution, communicationRound) + // If the aggregator has enough contributions then aggregate the weights + // and emit the 'aggregation' event + if (this.isFull()) { + const aggregatedWeights = this.aggregate() + // On each aggregation, increment the communication round + // If all communication rounds were performed, proceed to the next aggregation round + // and empty the past contributions. 
+ this._communicationRound++; + if (this.communicationRound === this.communicationRounds) { + this._communicationRound = 0 + this._round++; + this.contributions = Map() + } + // Emitting the 'aggregation' communicates the weights to subscribers + this.emit('aggregation', aggregatedWeights) + } + } + + // Abstract method to be implemented by subclasses + // Handles logging and adding the contribution to the list of the current round's contributions + protected abstract _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): void /** * Evaluates whether a given participant contribution can be used in the current aggregation round * the boolean returned by `this.add` is obtained via `this.isValidContribution` + * + * @param nodeId the node id of the contribution to be added + * @param round the aggregation round of the contribution to be added */ isValidContribution(nodeId: client.NodeID, round: number): boolean { if (!this.nodes.has(nodeId)) { @@ -101,7 +135,7 @@ export abstract class Base extends EventEmitter<{'aggregation': T }> { * Performs an aggregation step over the received node contributions. * Must store the aggregation's result in the aggregator's result promise. */ - abstract aggregate (): void + protected abstract aggregate (): WeightsContainer /** * Returns whether the given round is recent enough, dependent on the @@ -109,7 +143,7 @@ export abstract class Base extends EventEmitter<{'aggregation': T }> { * @param round The round * @returns True if the round is recent enough, false otherwise */ - isWithinRoundCutoff (round: number): boolean { + private isWithinRoundCutoff (round: number): boolean { return this.round - round <= this.roundCutoff } @@ -172,14 +206,6 @@ export abstract class Base extends EventEmitter<{'aggregation': T }> { this._nodes = nodeIds } - /** - * Empties the current set of "nodes". Usually called at the end of an aggregation round, - * if the set of nodes is meant to change or to be actualized. 
- */ - resetNodes (): void { - this._nodes = Set() - } - /** * Sets the aggregator's round number. To be used whenever the aggregator is out of sync * with the network's round. @@ -191,24 +217,11 @@ export abstract class Base extends EventEmitter<{'aggregation': T }> { } } - /** - * Updates the aggregator's state to proceed to the next communication round. - * If all communication rounds were performed, proceeds to the next aggregation round - * and empties the collection of stored contributions. - */ - public nextRound (): void { - if (++this._communicationRound === this.communicationRounds) { - this._communicationRound = 0 - this._round++ - this.contributions = Map() - } - } - /** * Constructs the payloads sent to other nodes as contribution. * @param base Object from which the payload is computed */ - abstract makePayloads (base: T): Map + abstract makePayloads (base: WeightsContainer): Map abstract isFull (): boolean @@ -226,17 +239,6 @@ export abstract class Base extends EventEmitter<{'aggregation': T }> { return this._round } - /** - * The aggregator's current size, defined by its number of contributions. The size is bounded by - * the amount of all active nodes times the number of communication rounds. - */ - get size (): number { - return this.contributions - .valueSeq() - .map((m) => m.size) - .reduce((totalSize: number, size) => totalSize + size) ?? 0 - } - /** * The current communication round. */ diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts index eb5aec32d..076945c58 100644 --- a/discojs/src/aggregator/get.ts +++ b/discojs/src/aggregator/get.ts @@ -11,9 +11,9 @@ type AggregatorOptions = Partial<{ /** * Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters. 
* Here is the ordered list of parameters used to define the aggregator and its default behavior: - * task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme + * task.trainingInformation.aggregationStrategy > options.scheme > task.trainingInformation.scheme * - * If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values. + * If `task.trainingInformation.aggregationStrategy` is defined, we initialize the chosen aggregator with `options` parameter values. * Otherwise, we default to a MeanAggregator for both training schemes. * * For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values. @@ -27,10 +27,10 @@ type AggregatorOptions = Partial<{ * @returns The aggregator */ export function getAggregator(task: Task, options: AggregatorOptions = {}): aggregator.Aggregator { - const aggregatorType = task.trainingInformation.aggregator ?? 'mean' + const aggregationStrategy = task.trainingInformation.aggregationStrategy ?? 'mean' const scheme = options.scheme ?? 
task.trainingInformation.scheme - switch (aggregatorType) { + switch (aggregationStrategy) { case 'mean': if (scheme === 'decentralized') { // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100% diff --git a/discojs/src/aggregator/index.ts b/discojs/src/aggregator/index.ts index 868278785..3310f7d16 100644 --- a/discojs/src/aggregator/index.ts +++ b/discojs/src/aggregator/index.ts @@ -1,10 +1,5 @@ -import type { WeightsContainer } from '../weights/index.js' -import type { Base } from './base.js' - -export { Base as AggregatorBase, AggregationStep } from './base.js' +export { Aggregator, AggregationStep } from './aggregator.js' export { MeanAggregator } from './mean.js' export { SecureAggregator } from './secure.js' -export { getAggregator } from './get.js' - -export type Aggregator = Base +export { getAggregator } from './get.js' \ No newline at end of file diff --git a/discojs/src/aggregator/mean.spec.ts b/discojs/src/aggregator/mean.spec.ts index 3a352c652..dc7575efe 100644 --- a/discojs/src/aggregator/mean.spec.ts +++ b/discojs/src/aggregator/mean.spec.ts @@ -11,44 +11,113 @@ async function WSIntoArrays(ws: WeightsContainer): Promise { } describe("mean aggregator", () => { - it("updates only within round cutoff", () => { - const aggregator = new MeanAggregator(1, 3); - aggregator.setNodes( - Set.of("first client", "second client", "third client"), - ); + it("updates only within round cutoff", async () => { + const aggregator = new MeanAggregator(1, 1, 'relative'); // use a round cutoff of 1 + aggregator.setNodes(Set.of("client 1")); // round 0 - - expect(aggregator.add("first client", WeightsContainer.of(), 0)).to.be.true; - - aggregator.nextRound(); + expect(aggregator.round).to.equal(0) + expect(aggregator.isValidContribution("client 1", 0)).to.be.true; + const client1Round0Promise = aggregator.getPromiseForAggregation(); + aggregator.add("client 1", WeightsContainer.of([1]), 0); + 
expect(WeightsContainer.of([1]).equals(await client1Round0Promise)).to.be.true + expect(aggregator.round).to.equal(1) + // round 1 - - expect(aggregator.add("second client", WeightsContainer.of(), 0)).to.be - .true; - expect(aggregator.add("first client", WeightsContainer.of(), 1)).to.be.true; - - aggregator.nextRound(); + aggregator.registerNode("client 2"); + expect(aggregator.isValidContribution("client 2", 0)).to.be.true; // round 0 should be within the cutoff + aggregator.add("client 1", WeightsContainer.of([1]), 1); + const client2Round0Promise = aggregator.getPromiseForAggregation(); + aggregator.add("client 2", WeightsContainer.of([2]), 0); + expect(WeightsContainer.of([1.5]).equals(await client2Round0Promise)).to.be.true + expect(aggregator.round).to.equal(2) + // round 2 - - expect(aggregator.add("third client", WeightsContainer.of(), 0)).to.be - .false; - expect(aggregator.add("second client", WeightsContainer.of(), 1)).to.be - .true; - expect(aggregator.add("first client", WeightsContainer.of(), 2)).to.be.true; + aggregator.registerNode("client 3"); + expect(aggregator.isValidContribution("client 3", 0)).to.be.false; // round 0 is now out of the cutoff + expect(aggregator.isValidContribution("client 3", 1)).to.be.true; + aggregator.add("client 1", WeightsContainer.of([1]), 2); + aggregator.add("client 2", WeightsContainer.of([1]), 2); + const client3Round2Promise = aggregator.getPromiseForAggregation(); + aggregator.add("client 3", WeightsContainer.of([4]), 1); + expect(WeightsContainer.of([2]).equals(await client3Round2Promise)).to.be.true + expect(aggregator.round).to.equal(3) }); it("returns the mean of the weights", async () => { - const aggregator = new MeanAggregator(0, 2); - aggregator.setNodes(Set.of("first client", "second client")); + const aggregator = new MeanAggregator(0, 2, 'absolute'); + const [id1, id2] = ["client 1", "client 2"] + + aggregator.setNodes(Set.of(id1, id2)); const results = new Promise((resolve) => 
aggregator.once("aggregation", resolve), ); - aggregator.add("first client", WeightsContainer.of([0], [1]), 0); - aggregator.add("second client", WeightsContainer.of([2], [3]), 0); + const result1 = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + const result2 = aggregator.getPromiseForAggregation(); + aggregator.add(id2, WeightsContainer.of([2], [3]), 0); + expect((await result1).equals(await result2)).to.be.true expect(await WSIntoArrays(await results)).to.deep.equal([[1], [2]]); }); + + it("waits for 100% of the contributions by default", async () => { + const aggregator = new MeanAggregator(); + const [id1, id2] = ["client 1", "client 2"] + + aggregator.setNodes(Set.of(id1, id2)); + + const result1 = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + // Make sure that the aggregation isn't triggered + expect(aggregator.round).equals(0) + + aggregator.registerNode(id2); + const result2 = aggregator.getPromiseForAggregation(); + aggregator.add(id2, WeightsContainer.of([2], [3]), 0); + expect((await result1).equals(await result2)).to.be.true + expect(aggregator.round).equals(1) // round should be one now + }); + + it("can wait for an absolute number of contributions", async () => { + const aggregator = new MeanAggregator(0, 1, 'absolute'); + const [id1, id2] = ["client 1", "client 2"] + aggregator.setNodes(Set.of(id1, id2)); // register two clients + + // should aggregate with only one contribution + const result = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + expect(await WSIntoArrays(await result)).to.deep.equal([[0], [1]]); + }); + + it("can wait for an relative number of contributions", async () => { + const aggregator = new MeanAggregator(0, 0.5, 'relative'); + const [id1, id2] = ["client 1", "client 2"] + aggregator.setNodes(Set.of(id1, id2)); // register two clients + + // should aggregate with only 50% of the 
contribution (1 contribution) + const result = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + expect(await WSIntoArrays(await result)).to.deep.equal([[0], [1]]); + }); + + it("doesn't aggregate when not enough participants", async () => { + const aggregator = new MeanAggregator(0, 1, 'absolute'); // only wait for a single participant + aggregator.minNbOfParticipants = 2 // However the task can specify another minimum number, here 2 + const [id1, id2] = ["client 1", "client 2"] + aggregator.setNodes(Set.of(id1)); + + const result1 = aggregator.getPromiseForAggregation(); + aggregator.add(id1, WeightsContainer.of([0], [1]), 0); + // Make sure that the aggregation isn't triggered + expect(aggregator.round).equals(0) + + aggregator.registerNode(id2); + const result2 = aggregator.getPromiseForAggregation(); + aggregator.add(id2, WeightsContainer.of([2], [3]), 0); + expect((await result1).equals(await result2)).to.be.true + expect(aggregator.round).equals(1) + }); }); diff --git a/discojs/src/aggregator/mean.ts b/discojs/src/aggregator/mean.ts index 423ee6542..cea0dba9f 100644 --- a/discojs/src/aggregator/mean.ts +++ b/discojs/src/aggregator/mean.ts @@ -1,7 +1,7 @@ import createDebug from "debug"; import type { Map } from "immutable"; -import { AggregationStep, Base as Aggregator } from "./base.js"; +import { AggregationStep, Aggregator } from "./aggregator.js"; import type { WeightsContainer, client } from "../index.js"; import { aggregation } from "../index.js"; @@ -13,7 +13,7 @@ type ThresholdType = 'relative' | 'absolute' * Mean aggregator whose aggregation step consists in computing the mean of the received weights. 
* */ -export class MeanAggregator extends Aggregator { +export class MeanAggregator extends Aggregator { readonly #threshold: number; readonly #thresholdType: ThresholdType; #minNbOfParticipants: number | undefined; @@ -64,7 +64,7 @@ export class MeanAggregator extends Aggregator { else { // Print a warning regarding the default behavior when thresholdType is not specified if (thresholdType === undefined) { - // TODO enforce validity by splitting features instead of warning + // TODO enforce validity by splitting the different threshold types into separate classes instead of warning debug( "[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " + "To instead wait for a single contribution, set thresholdType = 'absolute'" @@ -95,32 +95,17 @@ export class MeanAggregator extends Aggregator { this.#minNbOfParticipants = minNbOfParticipants } - override add( - nodeId: client.NodeID, - contribution: WeightsContainer, - round: number, - currentContributions: number = 0, - ): boolean { - if (currentContributions !== 0) - throw new Error("only a single communication round"); - - if (!this.isValidContribution(nodeId, round)) return false + override _add(nodeId: client.NodeID, contribution: WeightsContainer): void { this.log( - this.contributions.hasIn([0, nodeId]) - ? AggregationStep.UPDATE - : AggregationStep.ADD, + this.contributions.hasIn([0, nodeId]) ? 
AggregationStep.UPDATE : AggregationStep.ADD, nodeId, ); this.contributions = this.contributions.setIn([0, nodeId], contribution); - - if (this.isFull()) this.aggregate(); - - return true; } - override aggregate(): void { + override aggregate(): WeightsContainer { const currentContributions = this.contributions.get(0); if (currentContributions === undefined) throw new Error("aggregating without any contribution"); @@ -128,8 +113,7 @@ export class MeanAggregator extends Aggregator { this.log(AggregationStep.AGGREGATE); const result = aggregation.avg(currentContributions.values()); - // Emitting the event runs the superclass' callback to increment the round - this.emit('aggregation', result); + return result; } override makePayloads( diff --git a/discojs/src/aggregator/secure.spec.ts b/discojs/src/aggregator/secure.spec.ts index 8888cad39..35a12c3e1 100644 --- a/discojs/src/aggregator/secure.spec.ts +++ b/discojs/src/aggregator/secure.spec.ts @@ -66,7 +66,7 @@ describe("secure aggregator", () => { .entrySeq() .zip(Range(0, 3)) .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), - ), + ), 0 ); const secureResults = await communicate( Map( @@ -74,7 +74,7 @@ describe("secure aggregator", () => { .entrySeq() .zip(Range(0, 3)) .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), - ), + ), 0 ); List(await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays))) diff --git a/discojs/src/aggregator/secure.ts b/discojs/src/aggregator/secure.ts index 060c99612..064454e62 100644 --- a/discojs/src/aggregator/secure.ts +++ b/discojs/src/aggregator/secure.ts @@ -1,7 +1,7 @@ import { Map, List, Range } from "immutable"; import * as tf from "@tensorflow/tfjs"; -import { AggregationStep, Base as Aggregator } from "./base.js"; +import { AggregationStep, Aggregator } from "./aggregator.js"; import type { WeightsContainer, client } from "../index.js"; import { aggregation } from "../index.js"; @@ -12,12 +12,12 @@ import { aggregation } from "../index.js"; 
* - then, they sum their received shares and communicate the result. * Finally, nodes are able to average the received partial sums to establish the aggregation result. */ -export class SecureAggregator extends Aggregator { +export class SecureAggregator extends Aggregator { constructor(private readonly maxShareValue = 100) { super(0, 2); } - override aggregate(): void { + override aggregate(): WeightsContainer { this.log(AggregationStep.AGGREGATE); switch (this.communicationRound) { @@ -27,9 +27,7 @@ export class SecureAggregator extends Aggregator { if (currentContributions === undefined) throw new Error("aggregating without any contribution"); - const result = aggregation.sum(currentContributions.values()); - this.emit('aggregation', result); - break + return aggregation.sum(currentContributions.values()); } // Average the received partial sums case 1: { @@ -37,21 +35,18 @@ export class SecureAggregator extends Aggregator { if (currentContributions === undefined) throw new Error("aggregating without any contribution"); - const result = aggregation.avg(currentContributions.values()); - this.emit('aggregation', result); - break + return aggregation.avg(currentContributions.values()); } default: throw new Error("communication round is out of bounds"); } } - override add( + _add( nodeId: client.NodeID, contribution: WeightsContainer, - round: number, - communicationRound?: number, - ): boolean { + communicationRound: number, + ): void { switch (communicationRound) { case 0: case 1: @@ -60,23 +55,16 @@ export class SecureAggregator extends Aggregator { throw new Error("requires communication round to be 0 or 1"); } - if (!this.isValidContribution(nodeId, round)) return false - this.log( - this.contributions.hasIn([communicationRound, nodeId]) - ? AggregationStep.UPDATE - : AggregationStep.ADD, - nodeId, + this.contributions.hasIn([communicationRound, nodeId]) ? 
+ AggregationStep.UPDATE : AggregationStep.ADD, + nodeId.slice(0, 4), ); this.contributions = this.contributions.setIn( [communicationRound, nodeId], contribution, ); - - if (this.isFull()) this.aggregate(); - - return true; } override isFull(): boolean { @@ -92,7 +80,7 @@ export class SecureAggregator extends Aggregator { switch (this.communicationRound) { case 0: { const shares = this.generateAllShares(weights); - // Abitrarily assign our shares to the available nodes + // Arbitrarily assign our shares to the available nodes return Map( List(this.nodes).zip(shares) as List<[string, WeightsContainer]>, ); diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 4083b3417..ead8afa86 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -1,3 +1,4 @@ +import createDebug from "debug"; import axios from 'axios' import type { Model, Task, WeightsContainer, RoundStatus } from '../index.js' @@ -6,24 +7,35 @@ import type { NodeID } from './types.js' import type { EventConnection } from './event_connection.js' import type { Aggregator } from '../aggregator/index.js' import { EventEmitter } from '../utils/event_emitter.js' +import { type } from "./messages.js"; + +const debug = createDebug("discojs:client"); /** * Main, abstract, class representing a Disco client in a network, which handles * communication with other nodes, be it peers or a server. */ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{ - /** - * Own ID provided by the network's server. - */ + // Own ID provided by the network's server. protected _ownId?: NodeID + // The network's server. + protected _server?: EventConnection + // The aggregator's result produced after aggregation. + protected aggregationResult?: Promise /** - * The network's server. 
+ * When the server notifies clients to pause and wait until more + * participants join, we rely on this promise to wait + * until the server signals that the training can resume */ - protected _server?: EventConnection + protected promiseForMoreParticipants: Promise | undefined = undefined; + /** - * The aggregator's result produced after aggregation. + * When the server notifies the client that they can resume training + * after waiting for more participants, we want to be able to display what + * we were doing before waiting (training locally or updating our model). + * We use this attribute to store the status to rollback to when we stop waiting */ - protected aggregationResult?: Promise + private previousStatus: RoundStatus | undefined; constructor ( public readonly url: URL, // The network server's URL to connect to @@ -33,6 +45,18 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{ super() } + /** + * Communication callback called at the beginning of every training round. + */ + abstract onRoundBeginCommunication(): Promise; + + /** + * Communication callback called the end of every training round. + * @param weights The local weight update resulting for the current local training round + * @returns aggregated weights or the local weights upon error + */ + abstract onRoundEndCommunication(weights: WeightsContainer): Promise; + /** * Handles the connection process from the client to any sort of network server. * This method is overriden by the federated and decentralized clients @@ -45,8 +69,98 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{ /** * Handles the disconnection process of the client from any sort of network server. */ - async disconnect (): Promise {} + async disconnect(): Promise { } + + /** + * Emits the round status specified. 
It also stores the status emitted such that + * if the server tells the client to wait for more participants, it can display + * the waiting status and once enough participants join, it can display the previous status again + */ + protected saveAndEmit(status: RoundStatus) { + this.previousStatus = status + this.emit("status", status) + } + + /** + * For both federated and decentralized clients, we listen to the server to tell + * us whether there are enough participants to train. If not, we pause until further notice. + * When a client connects to the server, the server answers with the session information (id, + * number of participants) and whether there are enough participants. + * When there are the server sends a new EnoughParticipant message to update the client. + * + * `setMessageInversionFlag` is used to address the following scenario: + * 1. Client 1 connect to the server + * 2. Server answers with message A containing "not enough participants" + * 3. Before A arrives a new client joins. There are enough participants now. + * 4. Server updates client 1 with message B saying "there are enough participants" + * 5. Due to network and message sizes message B can arrive before A. + * i.e. "there are enough participants" arrives before "not enough participants" + * ending up with client 1 thinking it needs to wait for more participants. + * + * To keep track of this message inversion, `setMessageInversionFlag` + * tells us whether a message inversion occurred (by setting a boolean to true) + * + * @param setMessageInversionFlag function flagging whether a message inversion occurred + * between a NewNodeInfo message and an EnoughParticipant message. 
+ */ + protected setupServerCallbacks(setMessageInversionFlag: () => void) { + // Setup an event callback if the server signals that we should + // wait for more participants + this.server.on(type.WaitingForMoreParticipants, () => { + if (this.promiseForMoreParticipants !== undefined) + throw new Error("Server sent multiple WaitingForMoreParticipants messages") + debug(`[${shortenId(this.ownId)}] received WaitingForMoreParticipants message from server`) + // Display the waiting status right away + this.emit("status", "not enough participants") + // Upon receiving a WaitingForMoreParticipants message, + // the client will await for this promise to resolve before sending its + // local weight update + this.promiseForMoreParticipants = this.createPromiseForMoreParticipants() + }) + // As an example assume we need at least 2 participants to train, + // When two participants join almost at the same time, the server + // sends a NewNodeInfo with waitForMoreParticipants=true to the first participant + // and directly follows with an EnoughParticipants message when the 2nd participant joins + // However, the EnoughParticipants can arrive before the NewNodeInfo (which can be much bigger) + // so we check whether we received the EnoughParticipants before being assigned a node ID + this.server.once(type.EnoughParticipants, () => { + if (this._ownId === undefined) { + debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`) + setMessageInversionFlag() + } + }) + } + /** + * Method called when the server notifies the client that there aren't enough + * participants (anymore) to start/continue training + * The method creates a promise that will resolve once the server notifies + * the client that the training can resume via a subsequent EnoughParticipants message + * @returns a promise which resolves when enough participants joined the session + */ + protected async createPromiseForMoreParticipants(): Promise { + return new 
Promise((resolve) => { + // "once" is important because we can't resolve the same promise multiple times + this.server.once(type.EnoughParticipants, () => { + debug(`[${shortenId(this.ownId)}] received EnoughParticipants message from server`) + // Emit the last status emitted before waiting if defined + if (this.previousStatus !== undefined) this.emit("status", this.previousStatus) + resolve() + }) + }) + } + + protected async waitForParticipantsIfNeeded(): Promise{ + // we check if we are waiting for more participants before sending our weight update + if (this.waitingForMoreParticipants) { + // wait for the promise to resolve, which takes as long as it takes for new participants to join + debug(`[${shortenId(this.ownId)}] is awaiting the promise for more participants`) + this.emit("status", "not enough participants") + await this.promiseForMoreParticipants + // Make sure to set the promise back to undefined once resolved + this.promiseForMoreParticipants = undefined + } + } /** * Fetches the latest model available on the network's server, for the adequate task. * @returns The latest model @@ -63,24 +177,12 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{ } /** - * Communication callback called at the beginning of every training round. - */ - abstract onRoundBeginCommunication(): Promise; - - /** - * Communication callback called the end of every training round. 
- * @param weights The local weight update resulting for the current local training round - * @returns aggregated weights or the local weights upon error - */ - abstract onRoundEndCommunication(weights: WeightsContainer): Promise; - - // Number of contributors to a collaborative session - // If decentralized, it should be the number of peers - // If federated, it should the number of participants excluding the server - // If local it should be 1 - get nbOfParticipants(): number { - return this.aggregator.nodes.size // overriden by the federated client - } + * Number of contributors to a collaborative session + * If decentralized, it should be the number of peers + * If federated, it should be the number of participants excluding the server + * If local it should be 1 + */ + abstract getNbOfParticipants(): number; get ownId(): NodeID { if (this._ownId === undefined) { @@ -95,4 +197,16 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{ } return this._server } + /** + * Whether the client should wait until more + * participants join the session, i.e.
a promise has been created + */ + get waitingForMoreParticipants(): boolean { + return this.promiseForMoreParticipants !== undefined + } + +} + +export function shortenId(id: string): string { + return id.slice(0, 4) } diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index d68f3321f..f89b6c572 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -3,10 +3,11 @@ import { Map, Set } from 'immutable' import type { Model, WeightsContainer } from "../../index.js"; import { serialization } from "../../index.js"; -import { Client, type NodeID } from '../index.js' +import { Client, shortenId } from '../client.js' +import { type NodeID } from '../index.js' import { type, type ClientConnected } from '../messages.js' import { timeout } from '../utils.js' -import { type EventConnection, WebSocketServer, waitMessage, type PeerConnection, waitMessageWithTimeout } from '../event_connection.js' +import { WebSocketServer, waitMessage, type PeerConnection, waitMessageWithTimeout } from '../event_connection.js' import { PeerPool } from './peer_pool.js' import * as messages from './messages.js' @@ -22,9 +23,14 @@ export class DecentralizedClient extends Client { /** * The pool of peers to communicate with during the current training round. */ - private pool?: PeerPool - private connections?: Map - + #pool?: PeerPool + #connections?: Map + + override getNbOfParticipants(): number { + const nbOfParticipants = this.aggregator.nodes.size + return nbOfParticipants === 0 ? 1 : nbOfParticipants + } + // Used to handle timeouts and promise resolving after calling disconnect private get isDisconnected() : boolean { return this._server === undefined @@ -38,7 +44,7 @@ export class DecentralizedClient extends Client { * create peer-to-peer WebRTC connections with peers. The server is used to exchange * peers network information. 
*/ - async connect(): Promise { + override async connect(): Promise { const model = await super.connect() // Get the server base model const serverURL = new URL('', this.url.href) switch (this.url.protocol) { @@ -52,52 +58,55 @@ export class DecentralizedClient extends Client { throw new Error(`unknown protocol: ${this.url.protocol}`) } serverURL.pathname += `decentralized/${this.task.id}` + // Create a WebSocket connection with the server + // The client then waits for the server to forward it other clients' network information. + // Upon receiving other peers' information, the clients establish a peer-to-peer WebRTC connection. + this._server = await WebSocketServer.connect(serverURL, messages.isMessageFromServer, messages.isMessageToServer) + this.server.on(type.SignalForPeer, (event) => { + if (this.#pool === undefined) throw new Error('received signal but peer pool is undefined') + // Create a WebRTC connection with the peer + this.#pool.signal(event.peer, event.signal) + }) - this._server = await this.connectServer(serverURL) - + // cf.
setupServerCallbacks doc for explanation + let receivedEnoughParticipants = false + this.setupServerCallbacks(() => receivedEnoughParticipants = true) + const msg: ClientConnected = { type: type.ClientConnected } this.server.send(msg) - const { id } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) - debug(`[${id}] assigned id generated by server`); + const { id, waitForMoreParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) + + // This should come right after receiving the message to make sure + // we don't miss a subsequent message from the server + // We check if the server is telling us to wait for more participants + // and we also check if an EnoughParticipants message ended up arriving + // before the NewNodeInfo + if (waitForMoreParticipants && !receivedEnoughParticipants) { + // Create a promise that resolves when enough participants join + // The client will await this promise before sending its local weight update + this.promiseForMoreParticipants = this.createPromiseForMoreParticipants() + } + + debug(`[${shortenId(id)}] assigned id generated by server`); if (this._ownId !== undefined) { throw new Error('received id from server but was already received') } this._ownId = id - this.pool = new PeerPool(id) - + this.#pool = new PeerPool(id) return model } - /** - * Create a WebSocket connection with the server - * The client then waits for the server to forward it other client's network information. - * Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
- */ - private async connectServer (url: URL): Promise { - const server: EventConnection = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer) - - server.on(type.SignalForPeer, (event) => { - if (this.pool === undefined) { - throw new Error('received signal but peer pool is undefined') - } - // Create a WebRTC connection with the peer - this.pool.signal(event.peer, event.signal) - }) - - return server - } - - async disconnect (): Promise { + override async disconnect (): Promise { // Disconnect from peers - await this.pool?.shutdown() - this.pool = undefined + await this.#pool?.shutdown() + this.#pool = undefined - if (this.connections !== undefined) { - const peers = this.connections.keySeq().toSet() + if (this.#connections !== undefined) { + const peers = this.#connections.keySeq().toSet() this.aggregator.setNodes(this.aggregator.nodes.subtract(peers)) } // Disconnect from server @@ -116,27 +125,54 @@ export class DecentralizedClient extends Client { * and waits for it to resolve. * */ - override async onRoundBeginCommunication (): Promise { + override async onRoundBeginCommunication(): Promise { + // Notify the server we want to join the next round so that the server + // waits for us to be ready before sending the list of peers for the round + this.server.send({ type: type.JoinRound }) + // Store the promise for the current round's aggregation result. + // We will await for it to resolve at the end of the round when exchanging weight updates. 
+ this.aggregationResult = this.aggregator.getPromiseForAggregation() + this.saveAndEmit("local training") + return Promise.resolve() + } + + override async onRoundEndCommunication (weights: WeightsContainer): Promise { + if (this.aggregationResult === undefined) { + throw new TypeError('aggregation result promise is undefined') + } + // Save the status in case participants leave and we switch to waiting for more participants + // Once enough new participants join we can display the previous status again + this.saveAndEmit("connecting to peers") + // First we check if we are waiting for more participants before sending our weight update + await this.waitForParticipantsIfNeeded() + // Create peer-to-peer connections with all peers for the round + await this.establishPeerConnections() + // Exchange weight updates with peers and return aggregated weights + return await this.exchangeWeightUpdates(weights) + } + + /** + * Signal to the server that we are ready to exchange weights. + * Once enough peers are ready, the server sends the list of peers for this round + * and the peers can establish peer-to-peer connections with each other. + */ + private async establishPeerConnections(): Promise { if (this.server === undefined) { throw new Error("peer's server is undefined, make sure to call `client.connect()` first") - } if (this.pool === undefined) { + } if (this.#pool === undefined) { throw new Error('peer pool is undefined, make sure to call `client.connect()` first') } - this.emit("status", "Retrieving peers' information") + // Reset peers list at each round of training to make sure client works with an updated peers // list, maintained by the server. Adds any received weights to the aggregator. 
- // Tell the server we are ready for the next round const readyMessage: messages.PeerIsReady = { type: type.PeerIsReady } this.server.send(readyMessage) // Wait for the server to answer with the list of peers for the round try { - debug(`[${this.ownId}] is waiting for peer list for round ${this.aggregator.round}`); - const receivedMessage = await waitMessageWithTimeout( - this.server, type.PeersForRound, undefined, - "Timeout waiting for the round's peer list" - ) + debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); + const receivedMessage = await waitMessage(this.server, type.PeersForRound) const peers = Set(receivedMessage.peers) @@ -145,30 +181,25 @@ export class DecentralizedClient extends Client { } // Store the list of peers for the current round including ourselves this.aggregator.setNodes(peers.add(this.ownId)) + this.aggregator.setRound(receivedMessage.aggregationRound) // the server gives us the round number // Initiate peer to peer connections with each peer // When connected, create a promise waiting for each peer's round contribution - const connections = await this.pool.getPeers( + const connections = await this.#pool.getPeers( peers, this.server, - // Init receipt of peers weights - // this awaits the peer's weight update and adds it to - // our aggregator upon reception - (conn) => { this.receivePayloads(conn, this.aggregator.round) } + // Init receipt of peers weights. 
this awaits the peer's + // weight update and adds it to our aggregator upon reception + (conn) => this.receivePayloads(conn) ) - debug(`[${this.ownId}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS()); - this.connections = connections + debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS()); + this.#connections = connections } catch (e) { - debug(`Error for [${this.ownId}] while beginning round: %o`, e); + debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e); this.aggregator.setNodes(Set(this.ownId)) - this.connections = Map() + this.#connections = Map() } - - // Store the promise for the current round's aggregation result. - // We will await for it to resolve at the end of the round when exchanging weight updates. - this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve)) - this.emit("status", "Training the model on the data you connected") } /** @@ -177,66 +208,73 @@ export class DecentralizedClient extends Client { * @param connections * @param round */ - private receivePayloads (connections: Map, round: number): void { + private receivePayloads (connections: Map): void { connections.forEach(async (connection, peerId) => { - let currentCommunicationRounds = 0 debug(`waiting for peer ${peerId}`); - do { + for (let r = 0; r < this.aggregator.communicationRounds; r++) { try { const message = await waitMessageWithTimeout(connection, type.Payload, 60_000, "Timeout waiting for a contribution from peer " + peerId) const decoded = serialization.weights.decode(message.payload) - if (!this.aggregator.add(peerId, decoded, round, message.round)) { - debug(`[${this.ownId}] failed to add contribution from peer ${peerId}`); + if (!this.aggregator.isValidContribution(peerId, message.aggregationRound)) { + debug(`[${shortenId(this.ownId)}] failed to add contribution from peer ${shortenId(peerId)}`); + } + else { + 
debug(`[${shortenId(this.ownId)}] received payload from peer ${shortenId(peerId)}` + + ` for round (%d, %d)`, message.aggregationRound, message.communicationRound); + this.aggregator.once("aggregation", () => + debug(`[${shortenId(this.ownId)}] aggregated the model` + + ` for round (%d, %d)`, message.aggregationRound, message.communicationRound) + ) + this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound) } } catch (e) { if (this.isDisconnected) return - debug(`Error for [${this.ownId}] while receiving payloads: %o`, e); + debug(`Error for [${shortenId(this.ownId)}] while receiving payloads: %o`, e); } - } while (++currentCommunicationRounds < this.aggregator.communicationRounds) + } }) } - override async onRoundEndCommunication (weights: WeightsContainer): Promise { + private async exchangeWeightUpdates(weights: WeightsContainer): Promise { if (this.aggregationResult === undefined) { throw new TypeError('aggregation result promise is undefined') } - this.emit("status", "Updating the model with other participants' models") - + this.saveAndEmit("updating model") // Perform the required communication rounds. Each communication round consists in sending our local payload, // followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator. // A communication round's payload is the aggregation result of the previous communication round. The first // communication round simply sends our training result, i.e. model weights updates. This scheme allows for // the aggregator to define any complex multi-round aggregation mechanism. 
let result = weights; - for (let r = 0; r < this.aggregator.communicationRounds; r++) { + for (let communicationRound = 0; communicationRound < this.aggregator.communicationRounds; communicationRound++) { + const connections = this.#connections + if (connections === undefined) throw new Error("peer's connections is undefined") // Generate our payloads for this communication round and send them to all ready connected peers - if (this.connections !== undefined) { - const payloads = this.aggregator.makePayloads(result) - try { - await Promise.all(payloads.map(async (payload, id) => { - if (id === this.ownId) { - this.aggregator.add(this.ownId, payload, this.aggregator.round, r) - } else { - const peer = this.connections?.get(id) - if (peer !== undefined) { - const encoded = await serialization.weights.encode(payload) - const msg: messages.PeerMessage = { - type: type.Payload, - peer: id, - round: r, - payload: encoded - } - peer.send(msg) - debug(`[${this.ownId}] send weight update to peer ${msg.peer}: %O`, msg); - } - } - })) - } catch (cause) { - throw new Error('error while sending weights', { cause }) + const payloads = this.aggregator.makePayloads(result) + payloads.forEach(async (payload, id) => { + // add our own contribution to the aggregator + if (id === this.ownId) { + this.aggregator.add(this.ownId, payload, this.aggregator.round, communicationRound) + return } - } + // Send our payload to each peer + const peer = connections.get(id) + if (peer !== undefined) { + const encoded = await serialization.weights.encode(payload) + const msg: messages.PeerMessage = { + type: type.Payload, + peer: id, + aggregationRound: this.aggregator.round, + communicationRound, + payload: encoded + } + peer.send(msg) + debug(`[${shortenId(this.ownId)}] send weight update to peer ${shortenId(msg.peer)}` + + ` for round (%d, %d)`, this.aggregator.round, communicationRound); + } + }) // Wait for aggregation before proceeding to the next communication round. 
// The current result will be used as payload for the eventual next communication round. try { @@ -248,20 +286,16 @@ export class DecentralizedClient extends Client { if (this.isDisconnected) { return weights } - debug(`[${this.ownId}] while waiting for aggregation: %o`, e); + debug(`[${shortenId(this.ownId)}] while waiting for aggregation: %o`, e); break } // There is at least one communication round remaining - if (r < this.aggregator.communicationRounds - 1) { + if (communicationRound < this.aggregator.communicationRounds - 1) { // Reuse the aggregation result - this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve)) + this.aggregationResult = this.aggregator.getPromiseForAggregation() } } - - // Reset the peers list for the next round - this.aggregator.resetNodes() - return await this.aggregationResult } } diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index a9dfcd374..b282c256c 100644 --- a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -1,7 +1,8 @@ import { weights } from '../../serialization/index.js' import { type SignalData } from './peer.js' import { isNodeID, type NodeID } from '../types.js' -import { type, type ClientConnected, hasMessageType } from '../messages.js' +import { type, hasMessageType } from '../messages.js' +import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js' /// Phase 0 communication (between server and peers) @@ -18,15 +19,21 @@ export interface SignalForPeer { signal: SignalData } -// client who sent is ready +// peer wants to join the next round +export interface JoinRound { + type: type.JoinRound +} + +// peer who sent is ready export interface PeerIsReady { type: type.PeerIsReady } -// server send to client who to connect to +// server sends to each peer the list of peers to connect to export interface PeersForRound { type: 
type.PeersForRound peers: NodeID[] + aggregationRound: number } /// Phase 1 communication (between peers) @@ -34,7 +41,8 @@ export interface PeersForRound { export interface Payload { type: type.Payload peer: NodeID - round: number + aggregationRound: number + communicationRound: number payload: weights.Encoded } @@ -43,19 +51,20 @@ export interface Payload { export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | - PeersForRound + PeersForRound | + WaitingForMoreParticipants | + EnoughParticipants export type MessageToServer = ClientConnected | SignalForPeer | - PeerIsReady + PeerIsReady | + JoinRound export type PeerMessage = Payload export function isMessageFromServer (o: unknown): o is MessageFromServer { - if (!hasMessageType(o)) { - return false - } + if (!hasMessageType(o)) return false switch (o.type) { case type.NewDecentralizedNodeInfo: @@ -67,15 +76,16 @@ export function isMessageFromServer (o: unknown): o is MessageFromServer { 'signal' in o // TODO check signal content? case type.PeersForRound: return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID) + case type.WaitingForMoreParticipants: + case type.EnoughParticipants: + return true } return false } export function isMessageToServer (o: unknown): o is MessageToServer { - if (!hasMessageType(o)) { - return false - } + if (!hasMessageType(o)) return false switch (o.type) { case type.ClientConnected: @@ -83,6 +93,7 @@ export function isMessageToServer (o: unknown): o is MessageToServer { case type.SignalForPeer: return 'peer' in o && isNodeID(o.peer) && 'signal' in o // TODO check signal content? 
+ case type.JoinRound: case type.PeerIsReady: return true } @@ -91,9 +102,7 @@ export function isMessageToServer (o: unknown): o is MessageToServer { } export function isPeerMessage (o: unknown): o is PeerMessage { - if (!hasMessageType(o)) { - return false - } + if (!hasMessageType(o)) return false switch (o.type) { case type.Payload: diff --git a/discojs/src/client/decentralized/peer_pool.spec.ts b/discojs/src/client/decentralized/peer_pool.spec.ts index 0a5c88f0d..ed3e2be9a 100644 --- a/discojs/src/client/decentralized/peer_pool.spec.ts +++ b/discojs/src/client/decentralized/peer_pool.spec.ts @@ -45,7 +45,8 @@ describe('peer pool', function () { type: type.Payload, peer: id, payload: [1, 2, 3], - round: 0 + aggregationRound: 0, + communicationRound: 0 } } diff --git a/discojs/src/client/federated/federated_client.ts b/discojs/src/client/federated/federated_client.ts index 24b670b1f..556a414ec 100644 --- a/discojs/src/client/federated/federated_client.ts +++ b/discojs/src/client/federated/federated_client.ts @@ -1,11 +1,10 @@ import createDebug from "debug"; import { serialization } from "../../index.js"; -import type { Model, RoundStatus, WeightsContainer } from "../../index.js"; -import { Client } from "../client.js"; +import type { Model, WeightsContainer } from "../../index.js"; +import { Client, shortenId } from "../client.js"; import { type, type ClientConnected } from "../messages.js"; import { - type EventConnection, waitMessage, waitMessageWithTimeout, WebSocketServer, @@ -29,53 +28,18 @@ export class FederatedClient extends Client { // Total number of other federated contributors, including this client, excluding the server // E.g., if 3 users are training a federated model, nbOfParticipants is 3 #nbOfParticipants: number = 1; - - /** - * When the server notifies clients to pause and wait until more - * participants join, we rely on this promise to wait - * until the server signals that the training can resume - */ - #promiseForMoreParticipants: Promise 
| undefined = undefined; - - /** - * When the server notifies the client that they can resume training - * after waiting for more participants, we want to be able to display what - * we were doing before waiting (training locally or updating our model). - * We use this attribute to store the status to rollback to when we stop waiting - */ - #previousStatus: RoundStatus | undefined = undefined; - - /** - * Whether the client should wait until more - * participants join the session, i.e. a promise has been created - */ - get #waitingForMoreParticipants(): boolean { - return this.#promiseForMoreParticipants !== undefined - } // the number of participants excluding the server - override get nbOfParticipants(): number { + override getNbOfParticipants(): number { return this.#nbOfParticipants } - /** - * Opens a new WebSocket connection with the server and listens to new messages over the channel - */ - private async connectServer(url: URL): Promise { - const server: EventConnection = await WebSocketServer.connect( - url, - messages.isMessageFederated, // can only receive federated message types from the server - messages.isMessageFederated, // idem for messages that the client can send - ); - return server; - } - /** * Initializes the connection to the server, gets our node ID * as well as the latest training information: latest global model, current round and * whether we are waiting for more participants. 
*/ - async connect(): Promise { + override async connect(): Promise { const model = await super.connect() // Get the server base model const serverURL = new URL("", this.url.href); @@ -89,36 +53,17 @@ export class FederatedClient extends Client { default: throw new Error(`unknown protocol: ${this.url.protocol}`); } - serverURL.pathname += `federated/${this.task.id}`; + // Opens a new WebSocket connection with the server and listens to new messages over the channel + this._server = await WebSocketServer.connect( + serverURL, + messages.isMessageFederated, // can only receive federated message types from the server + messages.isMessageFederated, // idem for messages that the client can send + ); - this._server = await this.connectServer(serverURL); - - // Setup an event callback if the server signals that we should - // wait for more participants - this.server.on(type.WaitingForMoreParticipants, () => { - debug(`[${id.slice(0, 4)}] received WaitingForMoreParticipants message from server`) - // Display the waiting status right away - this.emit("status", "Waiting for more participants") - // Upon receiving a WaitingForMoreParticipants message, - // the client will await for this promise to resolve before sending its - // local weight update - this.#promiseForMoreParticipants = this.waitForMoreParticipants() - }) - - // As an example assume we need at least 2 participants to train, - // When two participants join almost at the same time, the server - // sends a NewFederatedNodeInfo with waitForMoreParticipants=true to the first participant - // and directly follows with an EnoughParticipants message when the 2nd participant joins - // However, the EnoughParticipants can arrive before the NewFederatedNodeInfo (which is much bigger) - // so we check whether we received the EnoughParticipants before being assigned a node ID + // c.f. 
setupServerCallbacks doc for explanation let receivedEnoughParticipants = false - this.server.once(type.EnoughParticipants, () => { - if (this._ownId === undefined) { - debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`) - receivedEnoughParticipants = true - } - }) + this.setupServerCallbacks(() => receivedEnoughParticipants = true) this.aggregator.registerNode(SERVER_NODE_ID); @@ -140,41 +85,22 @@ export class FederatedClient extends Client { if (waitForMoreParticipants && !receivedEnoughParticipants) { // Create a promise that resolves when enough participants join // The client will await this promise before sending its local weight update - this.#promiseForMoreParticipants = this.waitForMoreParticipants() + this.promiseForMoreParticipants = this.createPromiseForMoreParticipants() } if (this._ownId !== undefined) { throw new Error('received id from server but was already received') } this._ownId = id; - debug(`[${id.slice(0, 4)}] joined session at round ${round} `); + debug(`[${shortenId(id)}] joined session at round ${round} `); this.aggregator.setRound(round) this.#nbOfParticipants = nbOfParticipants // Upon connecting, the server answers with a boolean // which indicates whether there are enough participants or not - debug(`[${this.ownId.slice(0, 4)}] upon connecting, wait for participant flag %o`, this.#waitingForMoreParticipants) + debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants) model.weights = serialization.weights.decode(payload) return model } - /** - * Method called when the server notifies the client that there aren't enough - * participants (anymore) to start/continue training - * The method creates a promise that will resolve once the server notifies - * the client that the training can resume via a subsequent EnoughParticipants message - * @returns a promise which resolves when enough participants joined the session - */ - private async 
waitForMoreParticipants(): Promise { - return new Promise((resolve) => { - // "once" is important because we can't resolve the same promise multiple times - this.server.once(type.EnoughParticipants, () => { - debug(`[${this.ownId.slice(0, 4)}] received EnoughParticipants message from server`) - // Emit the last status emitted before waiting if defined - if (this.#previousStatus !== undefined) this.emit("status", this.#previousStatus) - resolve() - }) - }) - } - /** * Disconnection process when user quits the task. */ @@ -189,10 +115,7 @@ export class FederatedClient extends Client { override onRoundBeginCommunication(): Promise { // Prepare the result promise for the incoming round this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve)) - // Save the status in case participants leave and we switch to waiting for more participants - // Once enough new participants join we can display the previous status again - this.#previousStatus = "Training the model on the data you connected" - this.emit("status", this.#previousStatus) + this.saveAndEmit("local training") return Promise.resolve(); } @@ -212,19 +135,8 @@ export class FederatedClient extends Client { } // First we check if we are waiting for more participants before sending our weight update - if (this.#waitingForMoreParticipants) { - // wait for the promise to resolve, which takes as long as it takes for new participants to join - debug(`[${this.ownId.slice(0, 4)}] is awaiting the promise for more participants`) - this.emit("status", "Waiting for more participants") - await this.#promiseForMoreParticipants - // Make sure to set the promise back to undefined once resolved - this.#promiseForMoreParticipants = undefined - } - // Save the status in case participants leave and we switch to waiting for more participants - // Once enough new participants join we can display the previous status again - this.#previousStatus = "Updating the model with other participants' models" - 
this.emit("status", this.#previousStatus) - + await this.waitForParticipantsIfNeeded() + this.saveAndEmit("updating model") // Send our local contribution to the server // and receive the server global update for this round as an answer to our contribution const payloadToServer: WeightsContainer = this.aggregator.makePayloads(weights).first() @@ -236,9 +148,9 @@ export class FederatedClient extends Client { // Need to await the resulting global model right after sending our local contribution // to make sure we don't miss it - debug(`[${this.ownId.slice(0, 4)}] sent its local update to the server for round ${this.aggregator.round}`); + debug(`[${shortenId(this.ownId)}] sent its local update to the server for round ${this.aggregator.round}`); this.server.send(msg); - debug(`[${this.ownId.slice(0, 4)}] is waiting for server update for round ${this.aggregator.round + 1}`); + debug(`[${shortenId(this.ownId)}] is waiting for server update for round ${this.aggregator.round + 1}`); const { payload: payloadFromServer, round: serverRound, diff --git a/discojs/src/client/federated/messages.ts b/discojs/src/client/federated/messages.ts index 35b93ccf9..723ca704c 100644 --- a/discojs/src/client/federated/messages.ts +++ b/discojs/src/client/federated/messages.ts @@ -1,9 +1,8 @@ import { type weights } from '../../serialization/index.js' import { type NodeID } from '..//types.js' -import { - type, hasMessageType, type ClientConnected - } from '../messages.js' +import { type, hasMessageType } from '../messages.js' + import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js' // See ../messages.ts for doc export type MessageFederated = @@ -35,14 +34,6 @@ export interface ReceiveServerPayload { nbOfParticipants: number // number of peers contributing to a federated training } -export interface EnoughParticipants { - type: type.EnoughParticipants -} - -export interface WaitingForMoreParticipants { - type: type.WaitingForMoreParticipants -} 
- export function isMessageFederated (raw: unknown): raw is MessageFederated { if (!hasMessageType(raw)) { return false diff --git a/discojs/src/client/local_client.ts b/discojs/src/client/local_client.ts index e927364f8..01b408807 100644 --- a/discojs/src/client/local_client.ts +++ b/discojs/src/client/local_client.ts @@ -6,11 +6,16 @@ import { Client } from "./client.js"; * with anyone. Thus LocalClient doesn't do anything during communication */ export class LocalClient extends Client { - onRoundBeginCommunication(): Promise { + + override getNbOfParticipants(): number { + return 1; + } + + override onRoundBeginCommunication(): Promise { return Promise.resolve(); } // Simply return the local weights - onRoundEndCommunication(weights: WeightsContainer): Promise { + override onRoundEndCommunication(weights: WeightsContainer): Promise { return Promise.resolve(weights); } } \ No newline at end of file diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index ff54e6bb9..cbb47aee9 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -12,15 +12,18 @@ export enum type { // answers with its peer id and also tells the client whether we are waiting // for more participants before starting training NewDecentralizedNodeInfo, - // Message forwarded by the server from a client to another client - // to establish a peer-to-peer (WebRTC) connection - SignalForPeer, + // Message sent by peers to the server to signal they want to + // join the next round + JoinRound, // Message sent by nodes to server signaling they are ready to // start the next round PeerIsReady, // Sent by the server to participating peers containing the list // of peers for the round PeersForRound, + // Message forwarded by the server from a client to another client + // to establish a peer-to-peer (WebRTC) connection + SignalForPeer, // The weight update Payload, @@ -42,6 +45,14 @@ export interface ClientConnected { type: type.ClientConnected } +export 
interface EnoughParticipants { + type: type.EnoughParticipants +} + +export interface WaitingForMoreParticipants { + type: type.WaitingForMoreParticipants +} + export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | @@ -52,9 +63,7 @@ export type Message = export type NarrowMessage = Extract export function hasMessageType (raw: unknown): raw is { type: type } & Record { - if (typeof raw !== 'object' || raw === null) { - return false - } + if (typeof raw !== 'object' || raw === null) return false const o = raw as Record if ( diff --git a/discojs/src/dataset/data/image_data.ts b/discojs/src/dataset/data/image_data.ts index e4c771a02..c45639cc3 100644 --- a/discojs/src/dataset/data/image_data.ts +++ b/discojs/src/dataset/data/image_data.ts @@ -11,7 +11,7 @@ import { ImagePreprocessing, IMAGE_PREPROCESSING } from './preprocessing/index.j export class ImageData extends Data { public readonly availablePreprocessing = IMAGE_PREPROCESSING - static async init ( + static override async init ( dataset: tf.data.Dataset, task: Task, size?: number diff --git a/discojs/src/dataset/data/tabular_data.ts b/discojs/src/dataset/data/tabular_data.ts index cf8f119af..0a3f45651 100644 --- a/discojs/src/dataset/data/tabular_data.ts +++ b/discojs/src/dataset/data/tabular_data.ts @@ -11,7 +11,7 @@ import { TABULAR_PREPROCESSING } from './preprocessing/index.js' export class TabularData extends Data { public readonly availablePreprocessing = TABULAR_PREPROCESSING - static async init ( + static override async init ( dataset: tf.data.Dataset, task: Task, size?: number diff --git a/discojs/src/dataset/data/text_data.ts b/discojs/src/dataset/data/text_data.ts index b64b45740..daaee0848 100644 --- a/discojs/src/dataset/data/text_data.ts +++ b/discojs/src/dataset/data/text_data.ts @@ -11,7 +11,7 @@ import { TEXT_PREPROCESSING } from './preprocessing/index.js' export class TextData extends Data { public readonly availablePreprocessing = TEXT_PREPROCESSING - static 
init ( + static override init ( dataset: tf.data.Dataset, task: Task, size?: number diff --git a/discojs/src/default_tasks/cifar10.ts b/discojs/src/default_tasks/cifar10.ts index b9598e561..6c5f22235 100644 --- a/discojs/src/default_tasks/cifar10.ts +++ b/discojs/src/default_tasks/cifar10.ts @@ -33,8 +33,8 @@ export const cifar10: TaskProvider = { IMAGE_W: 224, LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], scheme: 'decentralized', + aggregationStrategy: 'mean', privacy: { clippingRadius: 20, noiseScale: 1 }, - decentralizedSecure: true, minNbOfParticipants: 3, maxShareValue: 100, tensorBackend: 'tfjs' diff --git a/discojs/src/default_tasks/lus_covid.ts b/discojs/src/default_tasks/lus_covid.ts index 356537a01..77889eb48 100644 --- a/discojs/src/default_tasks/lus_covid.ts +++ b/discojs/src/default_tasks/lus_covid.ts @@ -31,6 +31,7 @@ export const lusCovid: TaskProvider = { LABEL_LIST: ['COVID-Positive', 'COVID-Negative'], dataType: 'image', scheme: 'federated', + aggregationStrategy: 'mean', minNbOfParticipants: 2, tensorBackend: 'tfjs' } diff --git a/discojs/src/default_tasks/mnist.ts b/discojs/src/default_tasks/mnist.ts index 7c3f4b517..a55c2df7d 100644 --- a/discojs/src/default_tasks/mnist.ts +++ b/discojs/src/default_tasks/mnist.ts @@ -32,7 +32,7 @@ export const mnist: TaskProvider = { preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize], LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], scheme: 'decentralized', - decentralizedSecure: true, + aggregationStrategy: 'secure', minNbOfParticipants: 3, maxShareValue: 100, tensorBackend: 'tfjs' diff --git a/discojs/src/default_tasks/simple_face.ts b/discojs/src/default_tasks/simple_face.ts index ec8453179..e18d5c726 100644 --- a/discojs/src/default_tasks/simple_face.ts +++ b/discojs/src/default_tasks/simple_face.ts @@ -31,6 +31,7 @@ export const simpleFace: TaskProvider = { IMAGE_W: 200, LABEL_LIST: ['child', 
'adult'], scheme: 'federated', + aggregationStrategy: 'mean', minNbOfParticipants: 2, tensorBackend: 'tfjs' } diff --git a/discojs/src/default_tasks/titanic.ts b/discojs/src/default_tasks/titanic.ts index a2fc6b28f..393ee01c8 100644 --- a/discojs/src/default_tasks/titanic.ts +++ b/discojs/src/default_tasks/titanic.ts @@ -65,6 +65,7 @@ export const titanic: TaskProvider = { 'Survived' ], scheme: 'federated', + aggregationStrategy: 'mean', minNbOfParticipants: 2, tensorBackend: 'tfjs' } diff --git a/discojs/src/default_tasks/wikitext.ts b/discojs/src/default_tasks/wikitext.ts index a94b4207f..518c060c2 100644 --- a/discojs/src/default_tasks/wikitext.ts +++ b/discojs/src/default_tasks/wikitext.ts @@ -27,6 +27,7 @@ export const wikitext: TaskProvider = { dataType: 'text', preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding], scheme: 'federated', + aggregationStrategy: 'mean', minNbOfParticipants: 2, epochs: 6, // Unused by wikitext because data already comes split diff --git a/discojs/src/models/gpt/layers.ts b/discojs/src/models/gpt/layers.ts index 92578713e..7c0bc6fd5 100644 --- a/discojs/src/models/gpt/layers.ts +++ b/discojs/src/models/gpt/layers.ts @@ -8,11 +8,11 @@ import type { ModelSize } from './config.js' class Range extends tf.layers.Layer { static readonly className = 'Range' - computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { + override computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { return inputShape } - call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { + override call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { return tf.tidy(() => { if (Array.isArray(input)) { // TODO support multitensor @@ -30,11 +30,11 @@ tf.serialization.registerClass(Range) class LogLayer extends tf.layers.Layer { static readonly className = 'LogLayer' - computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | 
tf.Shape[] { + override computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { return inputShape } - call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { + override call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { return tf.tidy(() => { if (Array.isArray(input)) { input = input[0] @@ -77,7 +77,7 @@ class CausalSelfAttention extends tf.layers.Layer { this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0) } - build (): void { + override build (): void { this.cAttnKernel = this.addWeight( 'c_attn/kernel', [this.nEmbd, 3 * this.nEmbd], @@ -104,16 +104,16 @@ class CausalSelfAttention extends tf.layers.Layer { ) } - computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { + override computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { return inputShape } - getConfig (): tf.serialization.ConfigDict { + override getConfig (): tf.serialization.ConfigDict { const config = super.getConfig() return Object.assign({}, config, this.config) } - call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { + override call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { return tf.tidy(() => { if (this.cAttnKernel === undefined || this.cAttnBias === undefined || @@ -199,11 +199,11 @@ class GELU extends tf.layers.Layer { super({}) } - computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { + override computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { return inputShape } - call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { + override call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { return tf.tidy(() => { if (Array.isArray(input)) { // TODO support multitensor diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts index 0c220a7b6..c6efcb5de 
100644 --- a/discojs/src/models/gpt/model.ts +++ b/discojs/src/models/gpt/model.ts @@ -48,14 +48,14 @@ class GPTModel extends tf.LayersModel { return this.config } - compile() { + override compile() { if (this.optimizer !== undefined) return this.optimizer = this.config.weightDecay !== 0 ? getCustomAdam(this, this.config.lr, this.config.weightDecay) : tf.train.adam(this.config.lr) } - async fitDataset(dataset: Dataset, trainingArgs: tf.ModelFitDatasetArgs): Promise { + override async fitDataset(dataset: Dataset, trainingArgs: tf.ModelFitDatasetArgs): Promise { const callbacks = trainingArgs.callbacks as tf.CustomCallbackArgs const evalDataset = trainingArgs.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> await callbacks.onTrainBegin?.() diff --git a/discojs/src/models/gpt/optimizers.ts b/discojs/src/models/gpt/optimizers.ts index 7509d51f2..245df999b 100644 --- a/discojs/src/models/gpt/optimizers.ts +++ b/discojs/src/models/gpt/optimizers.ts @@ -91,7 +91,7 @@ class AdamW extends tf.AdamOptimizer { this.gradientClipNorm = p.gradientClipNorm } - applyGradients (variableGradients: Record | Array<{ name: string, tensor: tf.Tensor }>): void { + override applyGradients (variableGradients: Record | Array<{ name: string, tensor: tf.Tensor }>): void { const varNames: string[] = Array.isArray(variableGradients) ? variableGradients.map((v) => v.name) : Object.keys(variableGradients) diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index 369bfe91f..137b6d57f 100644 --- a/discojs/src/task/training_information.ts +++ b/discojs/src/task/training_information.ts @@ -38,10 +38,6 @@ export interface TrainingInformation { // use Differential Privacy, reduce training accuracy and improve privacy. privacy?: Privacy; - - // decentralizedSecure: Secure Aggregation on/off: - // Boolean. 
true for secure aggregation to be used, if the training scheme is decentralized, false otherwise - decentralizedSecure?: boolean // maxShareValue: Secure Aggregation: maximum absolute value of a number in a randomly generated share // default is 100, must be a positive number, check the docs/PRIVACY.md file for more information on significance of maxShareValue selection // only relevant if secure aggregation is true (for either federated or decentralized learning) @@ -49,9 +45,9 @@ export interface TrainingInformation { // minNbOfParticipants: minimum number of participants required to train collaboratively // In decentralized Learning the default is 3, in federated learning it is 2 minNbOfParticipants: number - // aggregator: aggregator to be used by the server for federated learning, or by the peers for decentralized learning - // default is 'average', other options include for instance 'bandit' - aggregator?: 'mean' | 'secure' // TODO: never used + // aggregationStrategy: aggregator to be used by the server for federated learning, or by the peers for decentralized learning + // default is 'mean' + aggregationStrategy?: 'mean' | 'secure' // tokenizer (string | PreTrainedTokenizer). This field should be initialized with the name of a Transformers.js pre-trained tokenizer, e.g., 'Xenova/gpt2'. // When the tokenizer is first called, the actual object will be initialized and loaded into this field for the subsequent tokenizations. 
tokenizer?: string | PreTrainedTokenizer @@ -104,10 +100,9 @@ export function isTrainingInformation (raw: unknown): raw is TrainingInformation IMAGE_H, IMAGE_W, LABEL_LIST, - aggregator, + aggregationStrategy, batchSize, dataType, - decentralizedSecure, privacy, epochs, inputColumns, @@ -132,8 +127,7 @@ export function isTrainingInformation (raw: unknown): raw is TrainingInformation typeof minNbOfParticipants !== 'number' || (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) || (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') || - (aggregator !== undefined && typeof aggregator !== 'string') || - (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') || + (aggregationStrategy !== undefined && typeof aggregationStrategy !== 'string') || (privacy !== undefined && !isPrivacy(privacy)) || (maxShareValue !== undefined && typeof maxShareValue !== 'number') || (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') || @@ -146,8 +140,8 @@ export function isTrainingInformation (raw: unknown): raw is TrainingInformation return false } - if (aggregator !== undefined) { - switch (aggregator) { + if (aggregationStrategy !== undefined) { + switch (aggregationStrategy) { case 'mean': break case 'secure': break default: return false @@ -161,7 +155,7 @@ export function isTrainingInformation (raw: unknown): raw is TrainingInformation default: return false } - // interdepences on data type + // interdependencies on data type if (dataType === 'image') { if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') { return false @@ -192,10 +186,9 @@ export function isTrainingInformation (raw: unknown): raw is TrainingInformation IMAGE_W, IMAGE_H, LABEL_LIST, - aggregator, + aggregationStrategy, batchSize, dataType, - decentralizedSecure, privacy, epochs, inputColumns, diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 5dcb86367..274c109bb 100644 --- 
a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -22,11 +22,10 @@ interface DiscoConfig { logger: Logger; } -export type RoundStatus = - "Waiting for more participants" | - "Retrieving peers' information" | - "Updating the model with other participants' models" | - "Training the model on the data you connected" +export type RoundStatus = 'not enough participants' | // Server notification to wait for more participants + 'updating model' | // fetching/aggregating local updates into a global model + 'local training' | // Training the model locally + 'connecting to peers' // for decentralized only, fetch the server's list of participating peers /** * Top-level class handling distributed training from a client's perspective. It is meant to be diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts index 3e367df61..4f885b8aa 100644 --- a/discojs/src/training/trainer.ts +++ b/discojs/src/training/trainer.ts @@ -120,7 +120,7 @@ export class Trainer { return { epochs: epochsLogs, - participants: this.#client.nbOfParticipants, + participants: this.#client.getNbOfParticipants(), }; } } diff --git a/discojs/src/utils/event_emitter.ts b/discojs/src/utils/event_emitter.ts index 3c93922bb..f6ea3c26d 100644 --- a/discojs/src/utils/event_emitter.ts +++ b/discojs/src/utils/event_emitter.ts @@ -2,7 +2,7 @@ import { List } from 'immutable' -type Listener = (_: T) => void +type Listener = (_: T) => void | Promise /** * Call handlers on given events @@ -10,7 +10,8 @@ type Listener = (_: T) => void * @typeParam I object/mapping from event name to emitted value type */ export class EventEmitter> { - private listeners: { + // List of callbacks to run per event + #listeners: { [E in keyof I]?: List<[once: boolean, _: Listener]>; } = {} @@ -31,14 +32,14 @@ export class EventEmitter> { } /** - * Register listener to call on event + * Register listener to call on event. 
* * @param event event name to listen to * @param listener handler to call */ on(event: E, listener: Listener): void { - const eventListeners = this.listeners[event] ?? List() - this.listeners[event] = eventListeners.push([false, listener]) + const eventListeners = this.#listeners[event] ?? List() + this.#listeners[event] = eventListeners.push([false, listener]) } /** @@ -48,8 +49,8 @@ export class EventEmitter> { * @param listener handler to call next time */ once(event: E, listener: Listener): void { - const eventListeners = this.listeners[event] ?? List() - this.listeners[event] = eventListeners.push([true, listener]) + const eventListeners = this.#listeners[event] ?? List() + this.#listeners[event] = eventListeners.push([true, listener]) } /** @@ -59,10 +60,10 @@ export class EventEmitter> { * @param value what to call listeners with */ emit(event: E, value: I[E]): void { - const eventListeners = this.listeners[event] ?? List() - this.listeners[event] = eventListeners.filterNot(([once]) => once) + const eventListeners = this.#listeners[event] ?? List() + this.#listeners[event] = eventListeners.filterNot(([once]) => once) - eventListeners.forEach(([_, listener]) => { listener(value) }) + eventListeners.forEach(async ([_, listener]) => { await listener(value) }) } } diff --git a/docs/FAQ.md b/docs/FAQ.md index d2beb00eb..a3fb332a9 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -1,5 +1,19 @@ # FAQ +### Local code changes are not taken into account + +Make sure to remember that `discojs`, `discojs-node` and `discojs-web` need to be re-built for code changes to be effective. You can automate this process by running `npm -w discojs run watch build` which watches for code changes to rebuild `discojs`. Similarly for `discojs-node` and `discojs-web`. + +If you are changing parameters of a default task in `discojs/default_task`, you also need to restart the server after re-building `discojs`. 
This is because the server initializes the tasks upon starting so later changes are not taken into account. + +In case of doubt, close everything, re-install dependencies (`npm ci`), re-build everything (`npm -ws run build`) and restart Disco. + +### Peers can't connect to each other in decentralized learning + +Make sure you are connected to the internet, without any VPN. Indeed, WebRTC needs a connection to reach an online server (the STUN server, `simple-peer` currently uses `stun.l.google.com:19302`) for peers to establish a direct connection. More information on WebRTC [here](https://developer.mozilla.org/en-US/docs/Web/API/WebRTC_API). + +You can troubleshoot the issue by trying [the simple use case](https://github.com/feross/simple-peer?tab=readme-ov-file#usage) of `simple-peer`. + ### Using TensorFlow.js on Mac laptops with M1 chips `TensorFlow.js` in version `3` currently supports M1 Mac laptops. However, make sure you have an `arm` Node.js executable installed (not `x86`). It can be checked using: diff --git a/docs/examples/custom_task.ts b/docs/examples/custom_task.ts index 911162017..fded9d788 100644 --- a/docs/examples/custom_task.ts +++ b/docs/examples/custom_task.ts @@ -63,12 +63,7 @@ const customTask: TaskProvider = { async function runServer (): Promise { // Create server - const server = await DiscoServer.of( - // with some tasks provided by Disco - defaultTasks.titanic, - // or your own custom task - customTask, - ) + const server = new DiscoServer() // You can also provide your own task object containing the URL of the model @@ -86,7 +81,10 @@ async function runServer (): Promise { // await server.addTask(customTask.getTask(), new URL('https://example.com/path/to/your/model.json')) // Start the server - await server.serve() + await server.serve(8080, + defaultTasks.titanic, // with some tasks provided by Disco + // or your own custom task + customTask,) } runServer().catch(console.error) diff --git a/docs/examples/training.ts
b/docs/examples/training.ts index 1270882ad..d3600deb2 100644 --- a/docs/examples/training.ts +++ b/docs/examples/training.ts @@ -27,8 +27,7 @@ async function main (): Promise { const NAME: string = 'titanic' // Launch a server instance - const discoServer = await Server.of(defaultTasks.simpleFace, defaultTasks.titanic) - const [server, url] = await discoServer.serve() + const [server, url] = await new Server().serve(undefined, defaultTasks.simpleFace, defaultTasks.titanic) // Get all pre-defined tasks const tasks = await fetchTasks(url) diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index 933615a90..0367e1aac 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -2,7 +2,7 @@ import createDebug from "debug"; import { v4 as randomUUID } from 'uuid' import msgpack from 'msgpack-lite' import type WebSocket from 'ws' -import { Map, Set } from 'immutable' +import { Map } from 'immutable' import { client } from '@epfml/discojs' @@ -14,23 +14,22 @@ import MessageTypes = client.messages.type const debug = createDebug("server:controllers:decentralized") export class DecentralizedController extends TrainingController { - /** - * Set of nodes who have contributed. - */ - private readyNodes = Set() - /** - * Map associating node ids to their open WebSocket connections. - */ - private connections: Map = Map() + // Map of nodes who want to join the round. + // The boolean value indicates if the node is ready to exchange weight updates (i.e. 
+ // the node has already sent a PeerIsReady message) + // We wait for all peers to be ready to exchange weight updates + #roundPeers = Map() + #aggregationRound = 0 handle (ws: WebSocket): void { - const minimumReadyPeers = this.task.trainingInformation.minNbOfParticipants + const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants // Peer id of the message sender let peerId = randomUUID() while (this.connections.has(peerId)) { peerId = randomUUID() } + const shortId = peerId.slice(0, 4) // How the server responds to messages ws.on('message', (data: Buffer) => { @@ -42,46 +41,32 @@ export class DecentralizedController extends TrainingController { switch (msg.type) { // A new peer joins the network for a task case MessageTypes.ClientConnected: { + debug(`peer [%s] joined ${this.task.id}`, shortId) this.connections = this.connections.set(peerId, ws) // Answer with client id in an NewNodeInfo message const msg: messages.NewDecentralizedNodeInfo = { type: MessageTypes.NewDecentralizedNodeInfo, id: peerId, - waitForMoreParticipants: this.readyNodes.size < minimumReadyPeers // ground work for #718 + waitForMoreParticipants: this.connections.size < minNbOfParticipants } - debug(`peer ${peerId} joined ${this.task.id}`); - ws.send(msgpack.encode(msg), { binary: true }) + // Send an update to participants if we can start/resume training + this.sendEnoughParticipantsMsgIfNeeded(peerId) break } - // Send by peers at the beginning of each training round to get the list + // Send by peers at the beginning of each training round to notify + // the server that they want to join the round + case MessageTypes.JoinRound: { + this.#roundPeers = this.#roundPeers.set(peerId, false) + break + } + // Send by peers when they are ready to exchange weight updates to get the list // of active peers for this round. 
case MessageTypes.PeerIsReady: { - const peers = this.readyNodes.add(peerId) - if (peers.size >= minimumReadyPeers) { - this.readyNodes = Set() - - peers - .map((id) => { - const readyPeerIDs: messages.PeersForRound = { - type: MessageTypes.PeersForRound, - peers: peers.delete(id).toArray() - } - const encoded = msgpack.encode(readyPeerIDs) - return [id, encoded] as [client.NodeID, Buffer] - }) - .map(([id, encoded]) => { - const conn = this.connections.get(id) - if (conn === undefined) { - throw new Error(`peer ${id} marked as ready but not connection to it`) - } - return [conn, encoded] as [WebSocket, Buffer] - }).forEach(([conn, encoded]) => { conn.send(encoded) } - ) - } else { - this.readyNodes = peers - } + this.#roundPeers = this.#roundPeers.set(peerId, true) + debug("Received peer ready from: %o", shortId) + this.sendPeersForRoundIfNeeded() break } // Forwards a peer's message to another destination peer @@ -105,5 +90,61 @@ export class DecentralizedController extends TrainingController { debug("when processing WebSocket message: %o", e); } }) + // Setup callback for client leaving the session + ws.on('close', () => { + // Remove the participant when the websocket is closed + this.connections = this.connections.delete(peerId) + this.#roundPeers = this.#roundPeers.delete(peerId) + debug("client [%s] left", shortId) + + // Check if we are already waiting for new participants to join + if (this.waitingForMoreParticipants) return + // If no, check if we are still above the minimum number of participant required + if (this.connections.size >= minNbOfParticipants) { + // Check if remaining peers are all ready to exchange weight updates + this.sendPeersForRoundIfNeeded() + return + } + // If we are below the minimum number of participants + // tell remaining participants to wait until more participants join + this.sendWaitForMoreParticipantsMsg() + }) + } + /** + * Check if we have enough participants to start the training + * and if all peers that joined the 
round are ready to exchange weight updates + * If so, send the list of peers for this round to all participants + */ + private sendPeersForRoundIfNeeded(): void { + const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants + const nbOfPeersReady = this.#roundPeers.filter(ready => ready).size + // First check if there are enough participants to start the round + // Then check if all peers that wanted to join this round are ready + if (nbOfPeersReady < minNbOfParticipants + || nbOfPeersReady != this.#roundPeers.size) return + // Once every peer that joined the round is ready, we can start the round + this.#roundPeers.keySeq() + .map((id) => { + const readyPeerIDs: messages.PeersForRound = { + type: MessageTypes.PeersForRound, + peers: this.#roundPeers.delete(id).keySeq().toArray(), + aggregationRound: this.#aggregationRound + } + debug("Sending peer list to: %o", id.slice(0, 4)) + + const encoded = msgpack.encode(readyPeerIDs) + return [id, encoded] as [client.NodeID, Buffer] + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id) + if (conn === undefined) { + throw new Error(`peer ${id} marked as ready but not connection to it`) + } + return [conn, encoded] as [WebSocket, Buffer] + }).forEach(([conn, encoded]) => { conn.send(encoded) }) + // empty the list of peers for the next round + this.#roundPeers = Map() + this.#aggregationRound++ } } + diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index 988a24a34..aa1e760e9 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -1,11 +1,9 @@ - import createDebug from "debug"; import WebSocket from 'ws' import { v4 as randomUUID } from 'uuid' import msgpack from 'msgpack-lite' -import { Map } from 'immutable' -import type { EncodedWeights, Task, WeightsContainer } from '@epfml/discojs' +import type { EncodedWeights, Task } from '@epfml/discojs' import { aggregator as 
aggregators, client, @@ -25,95 +23,22 @@ export class FederatedController extends TrainingController { By default the server waits for 100% of the nodes to send their contributions before aggregating the updates */ #aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') - /** - * Promise containing the current round's results. To be awaited on when providing clients - * with the most recent result. - */ - #result: Promise | undefined = undefined /** * The most up to date global weights. The model weights are already serialized and * can be sent to participants, before starting training, or when joining mid-training * or staled participants */ #latestGlobalWeights: EncodedWeights - /** - * Boolean used to know if we have enough participants to train or if - * we should be waiting for more - */ - #waitingForMoreParticipants = true - /** - * List of active participants along with their websockets - * the list allows updating participants about the training status - * i.e. waiting for more participants or resuming training - */ - #participants = Map() + constructor(task: Task, initialWeights: EncodedWeights) { super(task) this.#latestGlobalWeights = initialWeights - // start the perpetual promise loop - void this.storeAggregationResult() - } - /** - * Loop creating an aggregation result promise at each round. - * Because clients contribute to the round asynchronously, a promise is used to let them wait - * until the server has aggregated the weights. This loop creates a promise whenever the previous - * one resolved and awaits until it resolves. The promise is used in createPromiseForWeights. 
- * @param aggregator The aggregation handler - */ - private async storeAggregationResult (): Promise { - // Create a promise on the future aggregated weights - // Store the promise such that it is accessible from other methods - this.#result = new Promise((resolve) => this.#aggregator.once('aggregation', resolve)) - // The promise resolves once the server received enough contributions (through the handle method) - // and the aggregator aggregated the weights. - const globalModel = await this.#result - const serializedWeights = await serialization.weights.encode(globalModel) - this.#latestGlobalWeights = serializedWeights - - // Create a new promise for the next round - // TODO weird usage, should be handled inside of aggregator - void this.storeAggregationResult() - } - /** - * This method is called after received a local update. - * It puts the client on hold until the server has aggregated the weights - * by creating a Promise which will resolve once the server has received - * enough contributions. 
Relying on a promise is useful since clients may - * send their contributions at different times and a promise lets the server - * wait asynchronously for the results - * - * @param task the task to which the client is contributing - * @param aggregator the server aggregator, in order to access the current round - * @param ws the websocket through which send the aggregated weights - */ - private createPromiseForWeights (ws: WebSocket): void { - const promisedResult = this.#result - if (promisedResult === undefined) { - throw new Error(`result promise was not set`) - } - - // Wait for aggregation result to resolve with timeout, giving the network a time window - // to contribute to the model - void Promise.race([ - promisedResult, - client.timeout(30_000, "Timeout while waiting for enough participant contributions") //TODO: it doesn't make sense that the server is using the client utils' timeout - ]).then((result) => - [result, this.#aggregator.round] as [WeightsContainer, number]) - .then(async ([result, round]) => - [await serialization.weights.encode(result), round] as [serialization.weights.Encoded, number]) - .then(([serialized, round]) => { - debug("Sending global weights for round %o", round) - const msg: FederatedMessages.ReceiveServerPayload = { - type: MessageTypes.ReceiveServerPayload, - round, // send the current round number after aggregation - payload: serialized, - nbOfParticipants: this.#participants.size - } - ws.send(msgpack.encode(msg)) - }) - .catch((e) => debug("while waiting for weights: %o", e)) + // Save the latest weight updates to be able to send it to new or outdated clients + this.#aggregator.on('aggregation', async (weightUpdate) => { + this.#latestGlobalWeights = await serialization.weights.encode(weightUpdate) + }) } /** @@ -134,6 +59,7 @@ export class FederatedController extends TrainingController { while (!this.#aggregator.registerNode(clientId)) { clientId = randomUUID() } + const shortId = clientId.slice(0, 4) // Setup callbacks 
triggered upon receiving the different client messages ws.on('message', (data: Buffer) => { @@ -143,46 +69,28 @@ export class FederatedController extends TrainingController { return // TODO send back error } - // Currently expect two types of messages from clients: - // - a client connects to the task - // - a client sends a weight update + // Currently expect two types of message: + // - the client connects to the task + // - the client sends a weight update switch (msg.type) { /* * A new participant joins the task */ case MessageTypes.ClientConnected: { - debug(`client [%s] joined ${this.task.id}`, clientId.slice(0, 4)) - this.#participants = this.#participants.set(clientId, ws) // add the new client + debug(`client [%s] joined ${this.task.id}`, shortId) + this.connections = this.connections.set(clientId, ws) // add the new client - const waitForMoreParticipants = this.#participants.size < minNbOfParticipants const msg: FederatedMessages.NewFederatedNodeInfo = { type: MessageTypes.NewFederatedNodeInfo, id: clientId, - waitForMoreParticipants, + waitForMoreParticipants: this.connections.size < minNbOfParticipants, payload: this.#latestGlobalWeights, round: this.#aggregator.round, - nbOfParticipants: this.#participants.size + nbOfParticipants: this.connections.size } ws.send(msgpack.encode(msg)) - - debug("Wait for more participant flag: %o", waitForMoreParticipants) - - // If we were previously waiting for more participants to join and we now have enough, - // broadcast to previously waiting participants that the training can start - if (this.#waitingForMoreParticipants && !waitForMoreParticipants) { - this.#participants - // filter out the client that just joined as - // it already knows via the NewFederatedNodeInfo message - .filter((_, id) => id !== clientId) - .forEach((participantWs, participantId) => { - debug("Sending enough-participant message to client [%s]", participantId.slice(0, 4)) - const msg: FederatedMessages.EnoughParticipants = { - type: 
MessageTypes.EnoughParticipants - } - participantWs.send(msgpack.encode(msg)) - }) - } - this.#waitingForMoreParticipants = waitForMoreParticipants // update the attribute + // Send an update to participants if we can start/resume training + this.sendEnoughParticipantsMsgIfNeeded(clientId) break } /* @@ -191,21 +99,28 @@ export class FederatedController extends TrainingController { case MessageTypes.SendPayload: { const { payload, round } = msg if (this.#aggregator.isValidContribution(clientId, round)) { - // We need to create a promise waiting for the global model before adding the contribution to the aggregator - // (so that the aggregation and sending the global model to participants - // doesn't happen before the promise is created) - this.createPromiseForWeights(ws) - // This is assuming that the federated server's aggregator - // always works with a single communication round const weights = serialization.weights.decode(payload) - const addedSuccessfully = this.#aggregator.add(clientId, weights, round) - if (!addedSuccessfully) throw new Error("Aggregator's isValidContribution returned true but failed to add the contribution") - debug(`Successfully added contribution from client [%s] for round ${round}`, clientId.slice(0, 4)) + + // Create a callback to send the aggregated weight to the client + // when enough contributions are received + this.#aggregator.once('aggregation', async (weightUpdate) => { + debug("Sending global weights for round %o to client [%s]", this.#aggregator.round, shortId) + const msg: FederatedMessages.ReceiveServerPayload = { + type: MessageTypes.ReceiveServerPayload, + round: this.#aggregator.round, // send the current round number after aggregation + payload: await serialization.weights.encode(weightUpdate), + nbOfParticipants: this.connections.size + } + ws.send(msgpack.encode(msg)) + }) + // Add the contribution + this.#aggregator.add(clientId, weights, round) + debug(`Successfully added contribution from client [%s] for round 
${round}`, shortId) } else { // If the client sent an invalid or outdated contribution // the server answers with the current round and last global model update debug(`Dropped contribution from client [%s] for round ${round} ` + - `Sending last global model from round ${this.#aggregator.round - 1}`, clientId.slice(0, 4)) + `Sending last global model from round ${this.#aggregator.round - 1}`, shortId) // no latest model at the first round if (this.#latestGlobalWeights === undefined) return @@ -213,7 +128,7 @@ export class FederatedController extends TrainingController { type: MessageTypes.ReceiveServerPayload, round: this.#aggregator.round - 1, // send the model from the previous round payload: this.#latestGlobalWeights, - nbOfParticipants: this.#participants.size + nbOfParticipants: this.connections.size } ws.send(msgpack.encode(msg)) } @@ -225,26 +140,18 @@ export class FederatedController extends TrainingController { // Setup callback for client leaving the session ws.on('close', () => { // Remove the participant when the websocket is closed - this.#participants = this.#participants.delete(clientId) + this.connections = this.connections.delete(clientId) this.#aggregator.removeNode(clientId) - debug("client [%s] left", clientId.slice(0, 4)) + debug("client [%s] left", shortId) // Check if we dropped below the minimum number of participant required // or if we are already waiting for new participants to join - if (this.#participants.size >= minNbOfParticipants || - this.#waitingForMoreParticipants + if (this.connections.size >= minNbOfParticipants || + this.waitingForMoreParticipants ) return - this.#waitingForMoreParticipants = true - // Tell remaining participants to wait until more participants join - this.#participants - .forEach((participantWs, participantId) => { - debug("Telling remaining client [%s] to wait for participants", participantId.slice(0, 4)) - const msg: FederatedMessages.WaitingForMoreParticipants = { - type: 
MessageTypes.WaitingForMoreParticipants - } - participantWs.send(msgpack.encode(msg)) - }) - }) + // tell remaining participants to wait until more participants join + this.sendWaitForMoreParticipantsMsg() + }) } } \ No newline at end of file diff --git a/server/src/controllers/training_controller.ts b/server/src/controllers/training_controller.ts index 5efb1f339..0a1c36f07 100644 --- a/server/src/controllers/training_controller.ts +++ b/server/src/controllers/training_controller.ts @@ -1,7 +1,13 @@ +import createDebug from "debug"; import type WebSocket from 'ws' +import { Map } from 'immutable' +import msgpack from 'msgpack-lite' +import { client } from '@epfml/discojs' import type { Task } from '@epfml/discojs' +const debug = createDebug("server:controllers") + /** * The Controller abstraction is commonly used in Express * and comes from the MVC pattern (model-view-controller) @@ -18,10 +24,64 @@ import type { Task } from '@epfml/discojs' * */ export abstract class TrainingController { + /** + * Boolean used to know if we have enough participants to train or if + * we should be waiting for more + */ + protected waitingForMoreParticipants = true + /** + * List of active participants along with their websockets + * the list allows updating participants about the training status + * i.e. 
waiting for more participants or resuming training + */ + protected connections = Map() constructor(protected readonly task: Task) { } abstract handle( ws: WebSocket ): void + + /** + * If enough participants joined, notifies them that the training can start/resume + * + * @param currentId the id of the participant that just joined + */ + protected sendEnoughParticipantsMsgIfNeeded(currentId: client.NodeID) { + // If we are currently waiting for more participants to join and we now have enough, + // broadcast to previously waiting participants that the training can start + if (this.waitingForMoreParticipants && + this.connections.size >= this.task.trainingInformation.minNbOfParticipants) { + this.connections + // filter out the client that just joined as + // it already knows via the NewFederatedNodeInfo message + .delete(currentId) + .forEach((participantWs, participantId) => { + debug("Sending enough-participant message to client [%s]", participantId.slice(0, 4)) + const msg: client.messages.EnoughParticipants = { + type: client.messages.type.EnoughParticipants + } + participantWs.send(msgpack.encode(msg)) + }) + this.waitingForMoreParticipants = false // update the attribute + } + } + + /** + * Notifies participant that the number of participants drops below the minimum threshold + */ + protected sendWaitForMoreParticipantsMsg(): void { + // If we are below the minimum number of participants + // tell remaining participants to wait until more participants join + this.waitingForMoreParticipants = true + this.connections + .forEach((participantWs, participantId) => { + debug("Telling remaining client [%s] to wait for participants", participantId.slice(0, 4)) + const msg: client.messages.WaitingForMoreParticipants = { + type: client.messages.type.WaitingForMoreParticipants + } + participantWs.send(msgpack.encode(msg)) + }) + } + } \ No newline at end of file diff --git a/server/src/main.ts b/server/src/main.ts index 2fd7d6423..dbf33901a 100644 --- 
a/server/src/main.ts +++ b/server/src/main.ts @@ -11,8 +11,6 @@ import { Server } from "./server.js"; const PORT = 8080; const providers = Object.values(defaultTasks); -// Init the server with default tasks -const server = await Server.of(...providers); console.info("Server loaded the tasks below"); console.table( @@ -26,5 +24,6 @@ console.table( })), ); -const [_, serverURL] = await server.serve(PORT); +// Init the server with default tasks +const [_, serverURL] = await new Server().serve(PORT, ...providers); console.log(`Disco Server listening on ${serverURL.toString()}`); diff --git a/server/src/routes/task_router.ts b/server/src/routes/task_router.ts index d2af22928..c6d249396 100644 --- a/server/src/routes/task_router.ts +++ b/server/src/routes/task_router.ts @@ -6,23 +6,23 @@ import { Set } from 'immutable' import type { Task, TaskID } from '@epfml/discojs' import { serialization, isTask } from '@epfml/discojs' -import type { TaskInitializer } from '../task_initializer.js' +import type { TaskSet } from '../task_set.js' const debug = createDebug("server:router:task_router"); export class TaskRouter { readonly #expressRouter: express.Router - readonly #taskInitializer: TaskInitializer + readonly #taskSet: TaskSet - constructor(taskInitializer: TaskInitializer) { - this.#taskInitializer = taskInitializer + constructor(taskSet: TaskSet) { + this.#taskSet = taskSet this.#expressRouter = express.Router() // Return available tasks upon GET requests this.#expressRouter.get('/', (_, res) => { res .status(200) - .send(this.#taskInitializer.tasks.map(([t, _]) => t).toArray()) + .send(this.#taskSet.tasks.map(([t, _]) => t).toArray()) }) // POST request to add a new task @@ -43,7 +43,7 @@ export class TaskRouter { if (!serialization.model.isEncoded(encoded)) throw new Error("could not recognize model encoding") - this.#taskInitializer.addTask(newTask, encoded) + this.#taskSet.addTask(newTask, encoded) .then(() => res.status(200).end("Successful task upload")) .catch((e) 
=> { debug("while adding model: %o", e); @@ -54,10 +54,7 @@ export class TaskRouter { // delay listener because `this` (object) isn't fully constructed yet process.nextTick(() => { // a 'newTask' event is emitted when a new task is added - this.#taskInitializer.on('newTask', (t, _) => { - this.onNewTask(t) - return Promise.resolve() - }) + this.#taskSet.on('newTask', ({ task }) => this.onNewTask(task)) }) } @@ -91,7 +88,7 @@ export class TaskRouter { response.status(404) return } - const taskAndModel = this.#taskInitializer.tasks.find(([t, _]) => t.id === id) + const taskAndModel = this.#taskSet.tasks.find(([t, _]) => t.id === id) if (taskAndModel === undefined) { response.status(404) return diff --git a/server/src/routes/training_router.ts b/server/src/routes/training_router.ts index 7d9f567c3..c2b107ff9 100644 --- a/server/src/routes/training_router.ts +++ b/server/src/routes/training_router.ts @@ -4,7 +4,7 @@ import { Set } from 'immutable' import type { Task, EncodedModel } from '@epfml/discojs' import { serialization } from '@epfml/discojs' -import type { TaskInitializer } from '../task_initializer.js' +import type { TaskSet } from '../task_set.js' import { TrainingController, FederatedController, DecentralizedController } from '../controllers/index.js' /** @@ -19,7 +19,7 @@ export class TrainingRouter { #tasks = Set() constructor(private readonly trainingScheme: 'federated' | 'decentralized', - wsApplier: expressWS.Instance, taskInitializer: TaskInitializer) { + wsApplier: expressWS.Instance, taskSet: TaskSet) { this.#expressRouter = express.Router() wsApplier.applyTo(this.#expressRouter) @@ -28,13 +28,15 @@ export class TrainingRouter { /* delay listener because `this` (object) isn't fully constructed yet. * The lambda function inside process.nextTick is executed after the current operation * on the JS stack runs to completion and before the event loop is allowed to continue. 
- * this.onNewTask is registered as a listener to taskInitializer, which has 2 consequences: - * - this.onNewTask is executed on all the default tasks (which are already loaded in taskInitializer) - * - Every time a new task and model are added to taskInitializer, this.onNewTask is executed on them. + * this.onNewTask is registered as a listener to taskSet, which has 2 consequences: + * - this.onNewTask is executed on all the default tasks (which are already loaded in taskSet) + * - Every time a new task and model are added to taskSet, this.onNewTask is executed on them. * For every task and model, this.onNewTask creates a path /taskID and routes it to this.handle. */ process.nextTick(() => { - taskInitializer.on('newTask', async (t, m) => { await this.onNewTask(t, m) }) + taskSet.on('newTask', + async ({ task, encodedModel }) => { await this.onNewTask(task, encodedModel) } + ) }) } diff --git a/server/src/server.ts b/server/src/server.ts index c26cabcf2..9a534c128 100644 --- a/server/src/server.ts +++ b/server/src/server.ts @@ -7,7 +7,7 @@ import type * as http from "http"; import type { TaskProvider } from "@epfml/discojs"; import { TaskRouter, TrainingRouter } from './routes/index.js' -import { TaskInitializer } from "./task_initializer.js"; +import { TaskSet } from "./task_set.js"; const debug = createDebug("server"); @@ -21,25 +21,21 @@ const debug = createDebug("server"); * https://developer.mozilla.org/en-US/docs/Learn/Server-side/Express_Nodejs/Introduction */ export class Server { - readonly #taskInitializer = new TaskInitializer(); - - // Static method to asynchronously init the Server - static async of(...tasks: TaskProvider[]): Promise { - const ret = new Server(); - await Promise.all(tasks.map((t) => ret.addTask(t))); - return ret; - } + readonly #taskSet = new TaskSet(); async addTask(taskProvider: TaskProvider): Promise { - await this.#taskInitializer.addTask(taskProvider); + await this.#taskSet.addTask(taskProvider); } /** * start server * * @param 
port where to start, if not given, choose a random one + * @param tasks list of initial tasks to serve + * @returns a tuple with the server instance and the URL + * **/ - async serve(port?: number): Promise<[http.Server, URL]> { + async serve(port?: number, ...tasks: TaskProvider[]): Promise<[http.Server, URL]> { const wsApplier = expressWS(express(), undefined, { leaveRouterUntouched: true, }); @@ -50,9 +46,12 @@ export class Server { app.use(express.json({ limit: "50mb" })); app.use(express.urlencoded({ limit: "50mb", extended: false })); - const taskRouter = new TaskRouter(this.#taskInitializer) - const federatedRouter = new TrainingRouter('federated', wsApplier, this.#taskInitializer) - const decentralizedRouter = new TrainingRouter('decentralized', wsApplier, this.#taskInitializer) + const taskRouter = new TaskRouter(this.#taskSet) + const federatedRouter = new TrainingRouter('federated', wsApplier, this.#taskSet) + const decentralizedRouter = new TrainingRouter('decentralized', wsApplier, this.#taskSet) + // Important to add the tasks AFTER all the routers are initialized + // so that the 'newTask' event is emitted after the routers are ready + await Promise.all(tasks.map((t) => this.addTask(t))); wsApplier.getWss().on('connection', (ws, req) => { if (!federatedRouter.isValidUrl(req.url) && !decentralizedRouter.isValidUrl(req.url)) { diff --git a/server/src/task_initializer.ts b/server/src/task_set.ts similarity index 75% rename from server/src/task_initializer.ts rename to server/src/task_set.ts index 1bbb2a1de..2ac77be22 100644 --- a/server/src/task_initializer.ts +++ b/server/src/task_set.ts @@ -1,19 +1,19 @@ -import { List, Set } from 'immutable' +import { Set } from 'immutable' import fs from 'node:fs/promises' import tf from '@tensorflow/tfjs' import '@tensorflow/tfjs-node' import { Task, TaskProvider, isTask, - serialization, models, Model + serialization, models, Model, EventEmitter } from '@epfml/discojs' import type { EncodedModel } from 
'@epfml/discojs' /** - * The TaskInitializer essentially handles initializing a Task and - * its associated EncodedModel. + * The TaskSet essentially handles initializing a Task and + * loading its associated EncodedModel. * - * We rely on a TaskInitializer to abstract the (asynchronous) logic of getting the model + * We rely on a TaskSet to abstract the (asynchronous) logic of getting the model * when not provided. * Depending on the case, getting the model is done by reading the model files * from disk if they exists, downloading them from a URL or @@ -23,17 +23,17 @@ import type { EncodedModel } from '@epfml/discojs' * to clients. Since the server doesn't need to use the Model, we * simply leave it already encoded and ready to be sent to clients * - * Due to the asynchronous nature of `addTask`, TaskInitializer is an EventEmitter, + * Due to the asynchronous nature of `addTask`, TaskSet is an EventEmitter, * by registering callbacks on new tasks and emitting a 'newTask' event * when a new task has been added. * - * Tasks are usually passed to TaskInitializer when booting the server + * Tasks are usually passed to TaskSet when booting the server * and objects depending on tasks and models can subscribe to * the 'newTask' event to run callbacks whenever a new Task and EncodedModel are initialized. 
*/ -export class TaskInitializer { - // List of callback to apply to future task-model pairs added - private listeners = List<(t: Task, m: EncodedModel) => Promise>() +export class TaskSet extends EventEmitter<{ + "newTask": { task: Task, encodedModel: EncodedModel } +}>{ // Keep track of previously initialized task-model pairs #tasks = Set<[Task, EncodedModel]>() @@ -41,21 +41,6 @@ export class TaskInitializer { return this.#tasks } - // Register a callback to be ran on all tasks - on(_: 'newTask', callback: (t: Task, m: EncodedModel) => Promise): void { - // Apply the callback to already initialized task-model pairs - this.#tasks.forEach(async ([t, m]) => { await callback(t, m) }) - // Register the callback that will be ran when new tasks are added - this.listeners = this.listeners.push(callback) - } - - // Emit a 'newTask' event, - // It runs all the registered callbacks with the new task and model - #emit(_: 'newTask', task: Task, model: EncodedModel): void { - // Run all the callbacks on the newly added task - this.listeners.forEach(async (listener) => { await listener(task, model) }) - } - /** * Method to add a new task and optionally its associated model. 
* It accepts parameters in different formats and handles @@ -97,7 +82,7 @@ export class TaskInitializer { // Add the task-model pair to the set this.#tasks = this.#tasks.add([task, encodedModel]) - this.#emit('newTask', task, encodedModel) + this.emit('newTask', { task, encodedModel }) } /** diff --git a/server/tests/client/decentralized.spec.ts b/server/tests/client/decentralized.spec.ts index 8f6a92ad8..b042567a8 100644 --- a/server/tests/client/decentralized.spec.ts +++ b/server/tests/client/decentralized.spec.ts @@ -16,8 +16,7 @@ function test ( let server: http.Server let url: URL beforeEach(async () => { - const disco = await Server.of(TASK); - [server, url] = await disco.serve(); + [server, url] = await new Server().serve(undefined, TASK); }); afterEach(() => { server?.close() }) diff --git a/server/tests/client/federated.spec.ts b/server/tests/client/federated.spec.ts index 299a9e412..60041d137 100644 --- a/server/tests/client/federated.spec.ts +++ b/server/tests/client/federated.spec.ts @@ -15,8 +15,7 @@ describe("federated client", () => { let server: http.Server; let url: URL; beforeEach(async () => { - const disco = await Server.of(TASK_PROVIDER); - [server, url] = await disco.serve(); + [server, url] = await new Server().serve(undefined, TASK_PROVIDER); }); afterEach(() => { server?.close(); diff --git a/server/tests/e2e/decentralized.spec.ts b/server/tests/e2e/decentralized.spec.ts index 89f90f701..768ab7f62 100644 --- a/server/tests/e2e/decentralized.spec.ts +++ b/server/tests/e2e/decentralized.spec.ts @@ -1,11 +1,17 @@ import type * as http from 'node:http' -import { List } from 'immutable' +import { List, Repeat } from 'immutable' import { expect } from 'chai' +import path from "node:path"; + +import type { RoundStatus } from "@epfml/discojs"; +import { loadImagesInDir } from "@epfml/discojs-node"; +import { Queue } from './utils.js' import { aggregator as aggregators, client as clients, defaultTasks, + Disco, WeightsContainer, } from 
"@epfml/discojs"; @@ -36,8 +42,7 @@ describe('end-to-end decentralized', function () { let server: http.Server let url: URL beforeEach(async () => { - const disco = await Server.of(defaultTasks.cifar10); - [server, url] = await disco.serve(); + [server, url] = await new Server().serve(undefined, defaultTasks.cifar10, defaultTasks.lusCovid); }); afterEach(() => { server?.close() }) @@ -114,4 +119,165 @@ describe('end-to-end decentralized', function () { it('several rounds of cifar 10 with three secure aggregators yields consensus', async () => { await reachConsensus('secure', 3) }) + + it("peers emit expected statuses", async function () { + this.timeout(15_000); + const lusCovidTask = defaultTasks.lusCovid.getTask(); + lusCovidTask.trainingInformation = { + ...lusCovidTask.trainingInformation, + scheme: 'decentralized', + aggregationStrategy: 'mean', + epochs: 8, + minNbOfParticipants: 2, + } + + const DATASET_DIR = path.join("..", "datasets"); + + const [positive, negative] = [ + ( + await loadImagesInDir(path.join(DATASET_DIR, "lus_covid", "COVID+")) + ).zip(Repeat("COVID-Positive")), + ( + await loadImagesInDir(path.join(DATASET_DIR, "lus_covid", "COVID-")) + ).zip(Repeat("COVID-Negative")), + ]; + const dataset = positive.chain(negative); + + /** + * Then at each round (each call to `disco.trainByRound`) the event cycle is: + * a) During onRoundBeingCommunication, + * 1. the peer notifies the server that they want to join the next round + * 2. finishes by updating the status to "local training" + * (without waiting for a server answer) + * b) local training (the status remains "local training") + * c) During onRoundEndCommunication + * 1. the peer notifies the server that they are ready to share weights + * set status to "connecting to peers" + * 2. wait for the server to answer with the current round's peers list + * this is where the nb of participants is updated + * 3. establish peer-to-peer connections + * 4. 
set status to "updating model" and exchange weight updates + * + * Given this, it is important to note that calling disco.trainByRound().next() + * for the first time will perform a) and then b) where it stops and yields the round logs. + * Thus, c) isn't called and the weight sharing is not performed during this call to next(). + * Calling next() again will then run c), as well as a) and b) again. + * + * In this test the timeline is: + * - User 1 joins the task by themselves + * - User 2 joins + * - User 1 leaves + * - User 3 joins + * - User 2 & 3 leave + */ + const statusUpdateTime = 500 // allow some time for the client to update their status + + // Create User 1 + const discoUser1 = new Disco(lusCovidTask, url, { }); + const statusUser1 = new Queue(); + discoUser1.on("status", status => { statusUser1.put(status) }) + const generatorUser1 = discoUser1.trainByRound(["image", dataset]) + + // Have User 1 join the task and train locally for one round + const logUser1Round1 = await generatorUser1.next() + expect(logUser1Round1.done).to.be.false + // User 1 did a) and b) so their status should be Training + expect(await statusUser1.next()).equal("local training") + if (logUser1Round1.done) + throw Error("User 1 finished training at the 1st round") + // participant list not updated yet (updated at step c)) + expect((logUser1Round1.value).participants).equal(1) + + // Calling next() a 2nd time makes User 1 go to c) where the peer should + // stay stuck awaiting until another participant joins + const logUser1Round2Promise = generatorUser1.next() + await new Promise((res, _) => setTimeout(res, statusUpdateTime)) // Wait some time for the status to update + expect(await statusUser1.next()).equal("connecting to peers") // tries to connect to peers + expect(await statusUser1.next()).equal("not enough participants") // but has to wait for more participants + + // Create User 2 + const discoUser2 = new Disco(lusCovidTask, url, { }); + const statusUser2 = new Queue(); + 
discoUser2.on("status", status => { statusUser2.put(status) }) + const generatorUser2 = discoUser2.trainByRound(["image", dataset]) + + // Have User 2 join the task and train for one round + const logUser2Round1 = await generatorUser2.next() + expect(logUser2Round1.done).to.be.false + if (logUser2Round1.done) + throw Error("User 2 finished training at the 1st round") + // participant list not updated yet (updated at step c)) + expect((logUser2Round1.value).participants).equal(1) + // User 2 did a) and b) + expect(await statusUser2.next()).equal("local training") + // User 1 is still in c) now waiting for user 2 to be ready to exchange weight updates + expect(await statusUser1.next()).equal("connecting to peers") + + // Proceed with round 2 + // The server should answer with the round's peers list. + // Peers then exchange updates and then start training locally with the new weights + const logUser2Round2 = await generatorUser2.next() + const logUser1Round2 = await logUser1Round2Promise // the promise can resolve now + expect(logUser1Round2.done).to.be.false + expect(logUser2Round2.done).to.be.false + if (logUser1Round2.done || logUser2Round2.done) + throw Error("User 1 or 2 finished training at the 2nd round") + // nb of participants should now be updated + expect((logUser1Round2.value).participants).equal(2) + expect((logUser2Round2.value).participants).equal(2) + // User 1 and 2 did c), a) and b) + expect(await statusUser1.next()).equal("updating model") // second to last + expect(await statusUser1.next()).equal("local training") + + expect(await statusUser2.next()).equal("connecting to peers") // back to connecting when user 1 joins + expect(await statusUser2.next()).equal("updating model") + expect(await statusUser2.next()).equal("local training") + + // Have user 1 quit the session + await discoUser1.close() + // Make user 2 go to c) + const logUser2Round3Promise = generatorUser2.next() + await new Promise((res, _) => setTimeout(res, statusUpdateTime)) // Wait 
some time for the status to update + expect(await statusUser2.next()).equal("connecting to peers") + expect(await statusUser2.next()).equal("not enough participants") + + // Create User 3 + const discoUser3 = new Disco(lusCovidTask, url, { }); + const statusUser3 = new Queue(); + discoUser3.on("status", status => { statusUser3.put(status) }) + const generatorUser3 = discoUser3.trainByRound(["image", dataset]) + + // User 3 joins mid-training and trains one local round + const logUser3Round1 = await generatorUser3.next() + expect(logUser3Round1.done).to.be.false + if (logUser3Round1.done) + throw Error("User 3 finished training at the 1st round") + // participant list not updated yet + expect((logUser3Round1.value).participants).equal(1) + // User 3 did a) and b) + expect(await statusUser3.next()).equal("local training") + // User 2 is still in c) waiting for user 3 to be ready to exchange weights + expect(await statusUser2.next()).equal("connecting to peers") + + // User 3 notifies the server that they are ready to exchange weights + // then user 2 and 3 exchange weight updates + const logUser3Round3 = await generatorUser3.next() + const logUser2Round3 = await logUser2Round3Promise // the promise can resolve now + if (logUser3Round3.done || logUser2Round3.done) + throw Error("User 1 or 2 finished training at the 3rd round") + expect(logUser2Round3.value.participants).equal(2) + expect(logUser3Round3.value.participants).equal(2) + // both user 2 and 3 did c), a) and are now in b) + expect(await statusUser2.next()).equal("updating model") + expect(await statusUser2.next()).equal("local training") + + expect(await statusUser3.next()).equal("connecting to peers") + expect(await statusUser3.next()).equal("updating model") + expect(await statusUser3.next()).equal("local training") + + await discoUser2.close() + await new Promise((res, _) => setTimeout(res, statusUpdateTime)) // Wait some time for the status to update + expect(await statusUser3.next()).equal("not enough
participants") + await discoUser3.close() + }); }) diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index bec5a74f4..99794645c 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -28,12 +28,13 @@ describe("end-to-end federated", () => { let url: URL; beforeEach(async function () { this.timeout("10s"); - [server, url] = await Server.of( + [server, url] = await new Server().serve( + undefined, defaultTasks.cifar10, defaultTasks.lusCovid, defaultTasks.titanic, defaultTasks.wikitext, - ).then((s) => s.serve()); + ); }); afterEach(() => { server?.close(); @@ -192,8 +193,8 @@ describe("end-to-end federated", () => { * When disco.trainByRound is called for the first time, the client connects to the server * which returns the latest model, current round and nb of participants. * Then at each round the event cycle is: - * a) onRoundBeingCommunication which updates the status to TRAINING - * b) local training (the status remains TRAINING) + * a) onRoundBeingCommunication which updates the status to "local training" + * b) local training (the status remains "local training") * c) onRoundEndCommunication which sends the local update and * receives the global weights while emitting the status UPDATE * @@ -210,9 +211,6 @@ describe("end-to-end federated", () => { * - User 3 joins * - User 2 & 3 leave */ - const TRAINING: RoundStatus = "Training the model on the data you connected" - const WAITING: RoundStatus = "Waiting for more participants" - const UPDATING: RoundStatus = "Updating the model with other participants' models" const statusUpdateTime = 500 // Create User 1 @@ -231,7 +229,7 @@ describe("end-to-end federated", () => { // stay stuck awaiting until another participant joins const logUser1Round2Promise = generatorUser1.next() await new Promise((res,_) => setTimeout(res, statusUpdateTime)) // Wait some time for the status to update - expect(statusUser1).equal(WAITING) + 
expect(statusUser1).equal("not enough participants") // Create User 2 const discoUser2 = new Disco(lusCovidTask, url, { scheme: "federated" }); @@ -244,10 +242,10 @@ describe("end-to-end federated", () => { expect(logUser2Round1.done).to.be.false expect((logUser2Round1.value as RoundLogs).participants).equal(2) // User 2 did a) and b) - expect(statusUser2).equal(TRAINING) + expect(statusUser2).equal("local training") // User 1 is still in c) now waiting for user 2 to share their local update // and for the server to aggregate the local updates - expect(statusUser1).equal(UPDATING) + expect(statusUser1).equal("updating model") // Proceed with round 2 // the server should answer with the new global weights @@ -259,15 +257,15 @@ describe("end-to-end federated", () => { expect((logUser1Round2.value as RoundLogs).participants).equal(2) expect((logUser2Round2.value as RoundLogs).participants).equal(2) // User 1 and 2 did c), a) and b) - expect(statusUser1).equal(TRAINING) - expect(statusUser2).equal(TRAINING) + expect(statusUser1).equal("local training") + expect(statusUser2).equal("local training") // Have user 1 quit the session await discoUser1.close() // Make user 2 go to c) const logUser2Round3Promise = generatorUser2.next() await new Promise((res, _) => setTimeout(res, statusUpdateTime)) // Wait some time for the status to update - expect(statusUser2).equal(WAITING) + expect(statusUser2).equal("not enough participants") // Create User 3 const discoUser3 = new Disco(lusCovidTask, url, { scheme: "federated" }); @@ -280,10 +278,10 @@ describe("end-to-end federated", () => { expect(logUser3Round1.done).to.be.false expect((logUser3Round1.value as RoundLogs).participants).equal(2) // User 3 did a) and b) - expect(statusUser3).equal(TRAINING) + expect(statusUser3).equal("local training") // User 2 is still in c) waiting for user 3 to share their local update // and for the server to aggregate the local updates - expect(statusUser2).equal(UPDATING) + 
expect(statusUser2).equal("updating model") // User 3 sends their weights to the server const logUser3Round3 = await generatorUser3.next() @@ -294,12 +292,12 @@ describe("end-to-end federated", () => { expect(logUser2Round3.value.participants).equal(2) expect(logUser3Round3.value.participants).equal(2) // both user 2 and 3 did c), a) and are now in b) - expect(statusUser2).equal(TRAINING) - expect(statusUser3).equal(TRAINING) + expect(statusUser2).equal("local training") + expect(statusUser3).equal("local training") await discoUser2.close() await new Promise((res, _) => setTimeout(res, statusUpdateTime)) // Wait some time for the status to update - expect(statusUser3).equal(WAITING) + expect(statusUser3).equal("not enough participants") await discoUser3.close() }); }); diff --git a/server/tests/e2e/utils.ts b/server/tests/e2e/utils.ts new file mode 100644 index 000000000..342491a4b --- /dev/null +++ b/server/tests/e2e/utils.ts @@ -0,0 +1,20 @@ +import { List } from 'immutable'; + +export class Queue { + #content = List(); + + put(e: T) { + this.#content = this.#content.push(e); + } + + async next(): Promise { + for (;;) { + const ret = this.#content.first(); + if (ret !== undefined) { + this.#content = this.#content.shift() + return ret + } + await new Promise((resolve) => setTimeout(resolve, 10)); + } + } +} \ No newline at end of file diff --git a/tsconfig.base.json b/tsconfig.base.json index 7590b2094..fa7b3580a 100644 --- a/tsconfig.base.json +++ b/tsconfig.base.json @@ -9,7 +9,10 @@ "skipLibCheck": true, // don't pollute - "noEmitOnError": true + "noEmitOnError": true, + + // enforce using the override keyword + "noImplicitOverride": true } } diff --git a/webapp/src/components/task_creation_form/TaskForm.vue b/webapp/src/components/task_creation_form/TaskForm.vue index 1bc93dded..612f43bd6 100644 --- a/webapp/src/components/task_creation_form/TaskForm.vue +++ b/webapp/src/components/task_creation_form/TaskForm.vue @@ -24,8 +24,7 @@ field, { dataType, - scheme, 
- decentralizedSecure + scheme } )" > @@ -77,11 +76,6 @@ v-model="scheme" :field="field" /> - ()) -const setDecentralizedSecure = (v: boolean) => { decentralizedSecure.value = v } const formatSection = (section: FormSection, rawTask: any): any => { let fields = List(section.fields) @@ -297,7 +289,7 @@ const isFieldVisible = ( if (fieldDeps === undefined) { return true } - const potentialDependencies: Array = ['dataType', 'scheme', 'decentralizedSecure'] + const potentialDependencies: Array = ['dataType', 'scheme'] return potentialDependencies.every((key) => fieldDeps[key] !== dependencies[key]) } diff --git a/webapp/src/components/training/Description.vue b/webapp/src/components/training/Description.vue index 68075f307..2b6e3053b 100644 --- a/webapp/src/components/training/Description.vue +++ b/webapp/src/components/training/Description.vue @@ -109,7 +109,7 @@ const displayField = (section: FormSection, field: FormField): boolean => { if (deps === undefined) { return true } - const potentialDependencies: Array = ['dataType', 'scheme', 'decentralizedSecure'] + const potentialDependencies: Array = ['dataType', 'scheme'] return potentialDependencies.every((key) => props.task.trainingInformation[key] !== deps[key]) } return false diff --git a/webapp/src/components/training/Trainer.vue b/webapp/src/components/training/Trainer.vue index f9161db9d..2462fa72a 100644 --- a/webapp/src/components/training/Trainer.vue +++ b/webapp/src/components/training/Trainer.vue @@ -52,14 +52,14 @@
-
+
Status - {{ roundStatus }} + {{ roundStatus[1] }}
-
@@ -118,7 +118,7 @@