Allow decentralized users to join late and catch up #775

Merged
merged 56 commits into from
Oct 14, 2024
e3894fc
Rename Base to Aggregator and simplify event cycle
JulienVig Sep 10, 2024
8a6df7d
Update aggregator test
JulienVig Sep 11, 2024
abcb3a3
Clarify todo
JulienVig Sep 11, 2024
0e04d59
Simplify receiving a weight update
JulienVig Sep 11, 2024
aa37b7e
Simplify log
JulienVig Sep 11, 2024
c558494
include aggregation round and communication round in decentralized me…
JulienVig Sep 11, 2024
b8578e8
Rename Training to TrainingSteps
JulienVig Sep 11, 2024
1770f5c
Update round status messages
JulienVig Sep 11, 2024
e2bbf60
Add comment for non-awaited promised
JulienVig Sep 11, 2024
bc2be30
Extend discojs EventEmitter
JulienVig Sep 11, 2024
74c1803
Handle async listener and running listener on past events
JulienVig Sep 11, 2024
7709a9d
update event emitter subscription
JulienVig Sep 11, 2024
65bac8c
Update decentralized payload message
JulienVig Sep 11, 2024
0057353
Fix import rename
JulienVig Sep 11, 2024
4bec8c7
Fix linting error
JulienVig Sep 11, 2024
7620ae3
Update decentralized round status
JulienVig Sep 12, 2024
d0e6380
Put messages in common between federated and decentralized
JulienVig Sep 12, 2024
e2d9a28
Move participant management logic to parent class
JulienVig Sep 12, 2024
61aeb90
Make peers wait when not enough participants
JulienVig Sep 12, 2024
8a76aa3
Move participant waiting to parent class
JulienVig Sep 12, 2024
9d1fe7e
Replace private keywords by private fields
JulienVig Sep 19, 2024
d5815ae
Add a new JoinRound decentralized message
JulienVig Sep 19, 2024
95740ca
Enable peers to join/leave mid session. Establish p2p connection in o…
JulienVig Sep 19, 2024
40efe48
Rm unused decentralizedSecure parameter
JulienVig Sep 23, 2024
e5b0ab3
Specify aggregation strategy in default tasks
JulienVig Sep 23, 2024
99b8661
rm commented code
JulienVig Sep 24, 2024
53d10c9
Add more FAQ troubleshooting
JulienVig Sep 24, 2024
311ec60
Fix decentralized participant number
JulienVig Sep 24, 2024
119b7c6
Test client status
JulienVig Sep 24, 2024
f374212
discojs/src/aggregator/aggregator: clean code following @tharvik's co…
JulienVig Oct 7, 2024
6bcd84b
discojs/src/aggregator: rm unused methods
JulienVig Oct 9, 2024
7272d36
discojs/src/aggregator/mean.spec: inline client names
JulienVig Oct 7, 2024
750fada
discojs/src/aggregator/mean.spec: move away from time-based tests
JulienVig Oct 9, 2024
9c62cb7
discojs/src/aggregator: split add method into common public add and p…
JulienVig Oct 7, 2024
05d0567
discojs/src/client: change this.shortID method to shortenID function
JulienVig Oct 8, 2024
6d88cac
discojs/src/client: rename checkIfWaitForParticipants to waitForParti…
JulienVig Oct 8, 2024
9941dfd
discojs/src/client/decentralized/decentralized_client: get rid of rou…
JulienVig Oct 7, 2024
d931ad6
discojs/src/client/decentralized/decentralized_client: send payloads …
JulienVig Oct 9, 2024
fd82f63
discojs/client/decentralized_client: save and emit "connecting to pee…
JulienVig Oct 9, 2024
be0b2a3
server/tests/e2e: rely on a Queue to test RoundStatus
JulienVig Oct 9, 2024
28f0422
discojs/src/client|trainer: Make getNbOfParticipant parent method abs…
JulienVig Oct 8, 2024
066343f
discojs/src/default_tasks/mnist: make MNIST task use secure aggregation
JulienVig Oct 8, 2024
bcf2eca
server/decentralized_controllers: rename checkThenSendPeersForRound t…
JulienVig Oct 8, 2024
531e20e
server/federated_controller: use await instead of then
JulienVig Oct 8, 2024
3a6e794
server/decentralized_controller: invert condition to exit early
JulienVig Oct 8, 2024
1c3f61c
server/controllers: create common sendWaitForMoreParticipantsMsg method
JulienVig Oct 8, 2024
6438d41
server/controllers: clean checkIfEnoughParticipants and rename to sen…
JulienVig Oct 8, 2024
ee437e4
server/src: rename TaskInitializer to TaskSet
JulienVig Oct 9, 2024
b5449fd
server/tests/e2e/decentralized.spec: avoid casting as RoundLogs
JulienVig Oct 9, 2024
d44b786
*: renaming round statuses to lowercase
JulienVig Oct 8, 2024
5d20ac9
webapp/Trainer.vue: convert roundStatus type from string to [RoundSta…
JulienVig Oct 9, 2024
e2314c6
*: Provide initial tasks to the server when calling `serve` instead o…
JulienVig Oct 9, 2024
85b60a9
*: remove past event subscription on EventEmitter (now that tasks are…
JulienVig Oct 9, 2024
5c960cb
*: make aggregator.add check if contribution is valid
JulienVig Oct 7, 2024
dfb607d
*: Make `aggregator.add` return void to enforce listening to the 'agg…
JulienVig Oct 9, 2024
d75338a
*: enforce using override keyword
JulienVig Oct 8, 2024
3 changes: 1 addition & 2 deletions cli/src/benchmark_gpt.ts
Expand Up @@ -55,8 +55,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
contextLength, batchSize, modelPath } = args

// Launch a server instance
const disco = await Server.of(defaultTasks.wikitext);
const [server, url] = await disco.serve();
const [server, url] = await new Server().serve(undefined, defaultTasks.wikitext);

// Fetch the wikitext task from the server
const tasks = await fetchTasks(url)
Expand Down
3 changes: 1 addition & 2 deletions cli/src/cli.ts
Expand Up @@ -39,8 +39,7 @@ async function main (provider: TaskProvider, numberOfUsers: number): Promise<voi
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })

const discoServer = await Server.of(provider)
const [server, url] = await discoServer.serve()
const [server, url] = await new Server().serve(undefined, provider)

const data = await getTaskData(task)

Expand Down
18 changes: 11 additions & 7 deletions discojs/src/aggregator.spec.ts
@@ -1,5 +1,5 @@
import { expect } from "chai";
import { Map, Range, Set } from "immutable";
import { Map, Range, Set, List } from "immutable";

import { WeightsContainer } from "./index.js";
import {
Expand Down Expand Up @@ -32,11 +32,14 @@ AGGREGATORS.forEach(([name, Aggregator]) =>
const results = new Promise((resolve) =>
aggregator.on("aggregation", resolve),
);


let promises = List<Promise<WeightsContainer>>()
for (let i = 0; i < 3; i++)
for (let r = 0; r < aggregator.communicationRounds; r++)
aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r);

for (let r = 0; r < aggregator.communicationRounds; r++){
promises = promises.push(aggregator.getPromiseForAggregation())
aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r)
}
await Promise.all(promises)
await results; // nothing to test

expect(aggregator.round).to.equal(1);
Expand All @@ -57,7 +60,7 @@ AGGREGATORS.forEach(([name, Aggregator]) =>
id,
[agg, WeightsContainer.of([ws])],
]),
),
), 0
)
)
.valueSeq()
Expand Down Expand Up @@ -94,6 +97,7 @@ export function setupNetwork<A extends Aggregator>(
// run all rounds of communication
export async function communicate<A extends Aggregator>(
networkWithContributions: Map<NodeID, [A, WeightsContainer]>,
aggregationRound: number
): Promise<Map<NodeID, WeightsContainer>> {
const communicationsRound =
networkWithContributions.first()?.[0].communicationRounds;
Expand Down Expand Up @@ -125,7 +129,7 @@ export async function communicate<A extends Aggregator>(
agg
.makePayloads(contrib)
.entrySeq()
.forEach(([to, payload]) => network.get(to)?.add(id, payload, 0, r)),
.forEach(([to, payload]) => network.get(to)?.add(id, payload, aggregationRound, r)),
);

contributions = Map(await Promise.all(nextContributions));
Expand Down
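The updated test collects `getPromiseForAggregation()` before each call to `add`. A minimal sketch of why that ordering matters, assuming the emitter does not replay past events (as the commit "remove past event subscription on EventEmitter" suggests). `TinyEmitter` and `promiseForNext` are hypothetical stand-ins, not the discojs classes:

```typescript
// Hypothetical stand-in for an event emitter that supports one-shot listeners
// and does NOT replay events to listeners registered after the fact.
class TinyEmitter<T> {
  private onceListeners: Array<(value: T) => void> = [];

  once(listener: (value: T) => void): void {
    this.onceListeners.push(listener);
  }

  emit(value: T): void {
    // One-shot listeners are cleared before being called
    const listeners = this.onceListeners;
    this.onceListeners = [];
    listeners.forEach((listener) => listener(value));
  }
}

// Same shape as the convenience method in the diff: wrap `once` in a promise
function promiseForNext<T>(emitter: TinyEmitter<T>): Promise<T> {
  return new Promise<T>((resolve) => emitter.once(resolve));
}

const emitter = new TinyEmitter<number>();
const received: string[] = [];

// Subscribed before the event: this listener is called
emitter.once((value) => received.push(`early:${value}`));
emitter.emit(42);
// Subscribed after the event: this listener stays pending until the next emit
emitter.once((value) => received.push(`late:${value}`));
```

Under these assumptions, a promise obtained from `promiseForNext` after the emit would only resolve on a later event, which is why the test requests the promise first and contributes second.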
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import createDebug from "debug";
import { Map, Set } from 'immutable'

import type { client } from '../index.js'
import type { client, WeightsContainer } from '../index.js'

import { EventEmitter } from '../utils/event_emitter.js'

Expand All @@ -17,10 +17,10 @@ export enum AggregationStep {
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
* a result based off their aggregation, whenever some defined condition is met.
*
* Emits an event whenever an aggregation step is performed.
* Users wait for this event to fetch the aggregation result.
* Emits an event carrying the round's aggregated weights whenever an aggregation step is performed.
* Users subscribe to this event to get the aggregation result.
*/
export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {
export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsContainer }> {
/**
* Contains the ids of all active nodes, i.e. members of the aggregation group at
* a given round. It is a subset of all the nodes available in the network.
Expand All @@ -31,8 +31,8 @@ export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {
* It defines the effective aggregation group, which is possibly a subset
* of all active nodes, depending on the aggregation scheme.
*/
// communication round -> NodeID -> T
protected contributions: Map<number, Map<client.NodeID, T>>
// communication round -> NodeID -> WeightsContainer
protected contributions: Map<number, Map<client.NodeID, WeightsContainer>>

/**
* The current aggregation round, used for assessing whether a node contribution is recent enough
Expand Down Expand Up @@ -61,29 +61,63 @@ export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {

this.contributions = Map()
this._nodes = Set()
}

// On every aggregation, update the object's state to match the current aggregation
// and communication rounds.
this.on('aggregation', () => this.nextRound())
/**
* Convenience method to subscribe to the 'aggregation' event.
* Awaiting this promise yields the aggregated weights for the current round.
*
* @returns a promise for the aggregated weights
*/
getPromiseForAggregation(): Promise<WeightsContainer> {
return new Promise<WeightsContainer>((resolve) => this.once('aggregation', resolve));
}

/**
* Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
* The aggregation round is increased whenever a new global model is obtained and local models are updated.
* Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation
* which requires multiple steps to obtain a global model)
* The contribution will be aggregated during the next aggregation step.
* which requires multiple steps to obtain a global model)
* The contribution is aggregated during the next aggregation step.
*
* @param nodeId The node's id
* @param contribution The node's contribution
* @param round aggregation round in which the contribution was made
* @param communicationRound communication round, within the aggregation round, in which the contribution was made
* @returns boolean, true if the contribution has been successfully taken into account or False if it has been rejected
*/
abstract add (nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean
add(nodeId: client.NodeID, contribution: WeightsContainer,
aggregationRound: number, communicationRound?: number): void {
if (!this.isValidContribution(nodeId, aggregationRound))
throw new Error("Tried adding an invalid contribution. Handle this case before calling add.")

// call the abstract method _add, implemented by subclasses
this._add(nodeId, contribution, communicationRound)
// If the aggregator has enough contributions then aggregate the weights
// and emit the 'aggregation' event
if (this.isFull()) {
const aggregatedWeights = this.aggregate()
// On each aggregation, increment the communication round
// If all communication rounds were performed, proceed to the next aggregation round
// and empty the past contributions.
this._communicationRound++;
if (this.communicationRound === this.communicationRounds) {
this._communicationRound = 0
this._round++;
this.contributions = Map()
}
// Emitting the 'aggregation' event communicates the weights to subscribers
this.emit('aggregation', aggregatedWeights)
}
}

// Abstract method to be implemented by subclasses
// Handles logging and adding the contribution to the list of the current round's contributions
protected abstract _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): void

/**
* Evaluates whether a given participant contribution can be used in the current aggregation round
* the boolean returned by `this.add` is obtained via `this.isValidContribution`
*
* @param nodeId the node id of the contribution to be added
* @param round the aggregation round of the contribution to be added
*/
isValidContribution(nodeId: client.NodeID, round: number): boolean {
if (!this.nodes.has(nodeId)) {
Expand All @@ -101,15 +135,15 @@ export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {
* Performs an aggregation step over the received node contributions.
* Must store the aggregation's result in the aggregator's result promise.
*/
abstract aggregate (): void
protected abstract aggregate (): WeightsContainer

/**
* Returns whether the given round is recent enough, dependent on the
* aggregator's round cutoff.
* @param round The round
* @returns True if the round is recent enough, false otherwise
*/
isWithinRoundCutoff (round: number): boolean {
private isWithinRoundCutoff (round: number): boolean {
return this.round - round <= this.roundCutoff
}

Expand Down Expand Up @@ -172,14 +206,6 @@ export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {
this._nodes = nodeIds
}

/**
* Empties the current set of "nodes". Usually called at the end of an aggregation round,
* if the set of nodes is meant to change or to be actualized.
*/
resetNodes (): void {
this._nodes = Set()
}

/**
* Sets the aggregator's round number. To be used whenever the aggregator is out of sync
* with the network's round.
Expand All @@ -191,24 +217,11 @@ export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {
}
}

/**
* Updates the aggregator's state to proceed to the next communication round.
* If all communication rounds were performed, proceeds to the next aggregation round
* and empties the collection of stored contributions.
*/
public nextRound (): void {
if (++this._communicationRound === this.communicationRounds) {
this._communicationRound = 0
this._round++
this.contributions = Map()
}
}

/**
* Constructs the payloads sent to other nodes as contribution.
* @param base Object from which the payload is computed
*/
abstract makePayloads (base: T): Map<client.NodeID, T>
abstract makePayloads (base: WeightsContainer): Map<client.NodeID, WeightsContainer>

abstract isFull (): boolean

Expand All @@ -226,17 +239,6 @@ export abstract class Base<T> extends EventEmitter<{'aggregation': T }> {
return this._round
}

/**
* The aggregator's current size, defined by its number of contributions. The size is bounded by
* the amount of all active nodes times the number of communication rounds.
*/
get size (): number {
return this.contributions
.valueSeq()
.map((m) => m.size)
.reduce((totalSize: number, size) => totalSize + size) ?? 0
}

/**
* The current communication round.
*/
Expand Down
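The new `add` flow in this file (validate, store, aggregate when full, advance the round counters, then emit) can be sketched in isolation. This is a simplified illustration, not the discojs implementation: `SketchAggregator` is a hypothetical class, plain numbers stand in for `WeightsContainer`, the validity check ignores the round cutoff, and "full" is assumed to mean one contribution from every node.

```typescript
type Listener = (weights: number) => void;

// Hypothetical, simplified aggregator: numbers replace WeightsContainer
class SketchAggregator {
  private listeners: Listener[] = [];
  // communication round -> node id -> contribution
  private contributions = new Map<number, Map<string, number>>();
  private _round = 0;
  private _communicationRound = 0;

  constructor(
    private readonly nodes: Set<string>,
    readonly communicationRounds: number = 1,
  ) {}

  get round(): number { return this._round; }
  get communicationRound(): number { return this._communicationRound; }

  on(listener: Listener): void {
    this.listeners.push(listener);
  }

  // Assumption: only known nodes contributing to the current round are valid
  isValidContribution(nodeId: string, aggregationRound: number): boolean {
    return this.nodes.has(nodeId) && aggregationRound === this._round;
  }

  // Returns void: the result is only available through the listeners
  add(nodeId: string, contribution: number, aggregationRound: number): void {
    if (!this.isValidContribution(nodeId, aggregationRound))
      throw new Error("Tried adding an invalid contribution.");

    const current =
      this.contributions.get(this._communicationRound) ?? new Map<string, number>();
    current.set(nodeId, contribution);
    this.contributions.set(this._communicationRound, current);
    if (current.size < this.nodes.size) return; // not full yet

    const values = Array.from(current.values());
    const mean = values.reduce((sum, v) => sum + v, 0) / values.length;

    // Advance the communication round, then the aggregation round if needed
    this._communicationRound++;
    if (this._communicationRound === this.communicationRounds) {
      this._communicationRound = 0;
      this._round++;
      this.contributions = new Map();
    }
    this.listeners.forEach((listener) => listener(mean));
  }
}

const aggregator = new SketchAggregator(new Set(["a", "b"]));
const results: number[] = [];
aggregator.on((weights) => results.push(weights));
aggregator.add("a", 1, 0);
aggregator.add("b", 3, 0);
```

Because the counters are updated before the event is emitted, a listener that reads `aggregator.round` already sees the next round. Callers are expected to check `isValidContribution` themselves, since an invalid contribution throws rather than returning `false`.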
8 changes: 4 additions & 4 deletions discojs/src/aggregator/get.ts
Expand Up @@ -11,9 +11,9 @@ type AggregatorOptions = Partial<{
/**
* Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
* Here is the ordered list of parameters used to define the aggregator and its default behavior:
* task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
* task.trainingInformation.aggregationStrategy > options.scheme > task.trainingInformation.scheme
*
* If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
* If `task.trainingInformation.aggregationStrategy` is defined, we initialize the chosen aggregator with `options` parameter values.
* Otherwise, we default to a MeanAggregator for both training schemes.
*
* For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
Expand All @@ -27,10 +27,10 @@ type AggregatorOptions = Partial<{
* @returns The aggregator
*/
export function getAggregator(task: Task, options: AggregatorOptions = {}): aggregator.Aggregator {
const aggregatorType = task.trainingInformation.aggregator ?? 'mean'
const aggregationStrategy = task.trainingInformation.aggregationStrategy ?? 'mean'
const scheme = options.scheme ?? task.trainingInformation.scheme

switch (aggregatorType) {
switch (aggregationStrategy) {
case 'mean':
if (scheme === 'decentralized') {
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
Expand Down
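The precedence described in the doc comment above (`aggregationStrategy` first, then `options.scheme`, then the task's own scheme) can be illustrated with a small sketch. The types are hypothetical and trimmed down; the `"secure"` strategy name and the exact set of scheme values are assumptions, and the function only resolves the choice without constructing an aggregator.

```typescript
// Hypothetical, trimmed-down types; the real Task type carries many more fields
type Scheme = "federated" | "decentralized" | "local";
type AggregationStrategy = "mean" | "secure";

interface SketchTask {
  trainingInformation: {
    scheme: Scheme;
    aggregationStrategy?: AggregationStrategy;
  };
}

interface SketchOptions {
  scheme?: Scheme;
}

// Resolves which strategy and scheme would be used, mirroring the two
// nullish-coalescing fallbacks shown in getAggregator
function resolveAggregatorChoice(
  task: SketchTask,
  options: SketchOptions = {},
): { strategy: AggregationStrategy; scheme: Scheme } {
  const strategy = task.trainingInformation.aggregationStrategy ?? "mean";
  const scheme = options.scheme ?? task.trainingInformation.scheme;
  return { strategy, scheme };
}

// No strategy and no options: fall back to "mean" and the task's scheme
const fallback = resolveAggregatorChoice({
  trainingInformation: { scheme: "federated" },
});

// Explicit strategy plus an options override for the scheme
const overridden = resolveAggregatorChoice(
  { trainingInformation: { scheme: "federated", aggregationStrategy: "secure" } },
  { scheme: "decentralized" },
);
```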
9 changes: 2 additions & 7 deletions discojs/src/aggregator/index.ts
@@ -1,10 +1,5 @@
import type { WeightsContainer } from '../weights/index.js'
import type { Base } from './base.js'

export { Base as AggregatorBase, AggregationStep } from './base.js'
export { Aggregator, AggregationStep } from './aggregator.js'
export { MeanAggregator } from './mean.js'
export { SecureAggregator } from './secure.js'

export { getAggregator } from './get.js'

export type Aggregator = Base<WeightsContainer>
export { getAggregator } from './get.js'