Skip to content

Commit

Permalink
*: Make aggregator.add return void to enforce listening to the 'agg…
Browse files Browse the repository at this point in the history
…regation' event to get the aggregated weights
  • Loading branch information
JulienVig committed Oct 9, 2024
1 parent f6b12d4 commit 17ed69d
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 82 deletions.
6 changes: 4 additions & 2 deletions discojs/src/aggregator.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ AGGREGATORS.forEach(([name, Aggregator]) =>

let promises = List<Promise<WeightsContainer>>()
for (let i = 0; i < 3; i++)
for (let r = 0; r < aggregator.communicationRounds; r++)
promises = promises.push(aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r))
for (let r = 0; r < aggregator.communicationRounds; r++){
promises = promises.push(aggregator.getPromiseForAggregation())
aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r)
}
await Promise.all(promises)
await results; // nothing to test

Expand Down
73 changes: 30 additions & 43 deletions discojs/src/aggregator/aggregator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ export enum AggregationStep {
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
* a result based off their aggregation, whenever some defined condition is met.
*
* Emits an event whenever an aggregation step is performed.
* Users wait for this event to fetch the aggregation result.
* Emits an event whenever an aggregation step is performed with the counrd's aggregated weights.
* Users subscribes to this event to get the aggregation result.
*/
export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsContainer }> {
/**
Expand Down Expand Up @@ -61,69 +61,56 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon

this.contributions = Map()
this._nodes = Set()
}

// On each aggregation, increment
// updates the aggregator's state to proceed to the next communication round.
// If all communication rounds were performed, proceeds to the next aggregation round
// and empties the collection of stored contributions.
this.on('aggregation', () => {
this._communicationRound++;
if (this.communicationRound === this.communicationRounds) {
this._communicationRound = 0
this._round++
this.contributions = Map()
}
})
/**
* Convenience method to subscribe to the 'aggregation' event.
* Await this promise returns the aggregated weights for the current round.
*
* @returns a promise for the aggregated weights
*/
getPromiseForAggregation(): Promise<WeightsContainer> {
return new Promise<WeightsContainer>((resolve) => this.once('aggregation', resolve));
}

/**
* Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
* The aggregation round is increased whenever a new global model is obtained and local models are updated.
* Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation
* which requires multiple steps to obtain a global model)
* The contribution will be aggregated during the next aggregation step.
* The contribution is aggregated during the next aggregation step.
*
* @param nodeId The node's id
* @param contribution The node's contribution
* @returns a promise for the aggregated weights, or undefined if the contribution is invalid
*/
async add(nodeId: client.NodeID, contribution: WeightsContainer,
aggregationRound: number, communicationRound?: number): Promise<WeightsContainer> {
add(nodeId: client.NodeID, contribution: WeightsContainer,
aggregationRound: number, communicationRound?: number): void {
if (!this.isValidContribution(nodeId, aggregationRound))
throw new Error("Tried adding an invalid contribution. Handle this case before calling add.")

// call the abstract method _add, implemented by subclasses
this._add(nodeId, contribution, communicationRound)
return this.createAggregationPromise()
}

// Abstract method to be implemented by subclasses
// Handles logging and adding the contribution to the list of the current round's contributions
protected abstract _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): void

/**
* Create a promise which resolves when enough contributions are received and
* local updates are aggregated.
* If the aggregator has enough contribution then we can aggregate the weights
* directly (and emit the 'aggregation' event)
* Otherwise we wait for the 'aggregation' event which will be emitted once
* enough contributions are received
*
* @returns a promise for the aggregated weights
*/
protected createAggregationPromise(): Promise<WeightsContainer> {
// Wait for the aggregation event to be emitted
const ret = new Promise<WeightsContainer>((resolve) => this.once('aggregation', resolve));

// If the aggregator has enough contributions then aggregate the weights
// and emit the 'aggregation' event
if (this.isFull()) {
const aggregatedWeights = this.aggregate()
// Emitting the 'aggregation' communicates the aggregation to other clients and
// takes care of incrementing the round
// On each aggregation, increment the communication round
// If all communication rounds were performed, proceed to the next aggregation round
// and empty the past contributions.
this._communicationRound++;
if (this.communicationRound === this.communicationRounds) {
this._communicationRound = 0
this._round++;
this.contributions = Map()
}
// Emitting the 'aggregation' communicates the weights to subscribers
this.emit('aggregation', aggregatedWeights)
}

return ret
}

// Abstract method to be implemented by subclasses
// Handles logging and adding the contribution to the list of the current round's contributions
protected abstract _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): void

/**
* Evaluates whether a given participant contribution can be used in the current aggregation round
Expand Down
49 changes: 30 additions & 19 deletions discojs/src/aggregator/mean.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,29 @@ describe("mean aggregator", () => {
// round 0
expect(aggregator.round).to.equal(0)
expect(aggregator.isValidContribution("client 1", 0)).to.be.true;
const client1Round0Promise = await aggregator.add("client 1", WeightsContainer.of([1]), 0);
expect(WeightsContainer.of([1]).equals(client1Round0Promise)).to.be.true
const client1Round0Promise = aggregator.getPromiseForAggregation();
aggregator.add("client 1", WeightsContainer.of([1]), 0);
expect(WeightsContainer.of([1]).equals(await client1Round0Promise)).to.be.true
expect(aggregator.round).to.equal(1)

// round 1
aggregator.registerNode("client 2");
expect(aggregator.isValidContribution("client 2", 0)).to.be.true; // round 0 should be within the cutoff
void aggregator.add("client 1", WeightsContainer.of([1]), 1);
const client2Round0Promise = await aggregator.add("client 2", WeightsContainer.of([2]), 0);
expect(WeightsContainer.of([1.5]).equals(client2Round0Promise)).to.be.true
aggregator.add("client 1", WeightsContainer.of([1]), 1);
const client2Round0Promise = aggregator.getPromiseForAggregation();
aggregator.add("client 2", WeightsContainer.of([2]), 0);
expect(WeightsContainer.of([1.5]).equals(await client2Round0Promise)).to.be.true
expect(aggregator.round).to.equal(2)

// round 2
aggregator.registerNode("client 3");
expect(aggregator.isValidContribution("client 3", 0)).to.be.false; // round 0 is now out of the cutoff
expect(aggregator.isValidContribution("client 3", 1)).to.be.true;
void aggregator.add("client 1", WeightsContainer.of([1]), 2);
void aggregator.add("client 2", WeightsContainer.of([1]), 2);
const client3Round2Promise = await aggregator.add("client 3", WeightsContainer.of([4]), 1);
expect(WeightsContainer.of([2]).equals(client3Round2Promise)).to.be.true
aggregator.add("client 1", WeightsContainer.of([1]), 2);
aggregator.add("client 2", WeightsContainer.of([1]), 2);
const client3Round2Promise = aggregator.getPromiseForAggregation();
aggregator.add("client 3", WeightsContainer.of([4]), 1);
expect(WeightsContainer.of([2]).equals(await client3Round2Promise)).to.be.true
expect(aggregator.round).to.equal(3)
});

Expand All @@ -51,8 +54,10 @@ describe("mean aggregator", () => {
aggregator.once("aggregation", resolve),
);

const result1 = aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]), 0);
const result1 = aggregator.getPromiseForAggregation();
aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
const result2 = aggregator.getPromiseForAggregation();
aggregator.add(id2, WeightsContainer.of([2], [3]), 0);
expect((await result1).equals(await result2)).to.be.true

expect(await WSIntoArrays(await results)).to.deep.equal([[1], [2]]);
Expand All @@ -64,12 +69,14 @@ describe("mean aggregator", () => {

aggregator.setNodes(Set.of(id1, id2));

const result1 = aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
const result1 = aggregator.getPromiseForAggregation();
aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
// Make sure that the aggregation isn't triggered
expect(aggregator.round).equals(0)

aggregator.registerNode(id2);
const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]), 0);
const result2 = aggregator.getPromiseForAggregation();
aggregator.add(id2, WeightsContainer.of([2], [3]), 0);
expect((await result1).equals(await result2)).to.be.true
expect(aggregator.round).equals(1) // round should be one now
});
Expand All @@ -80,8 +87,9 @@ describe("mean aggregator", () => {
aggregator.setNodes(Set.of(id1, id2)); // register two clients

// should aggregate with only one contribution
const result = await aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
expect(await WSIntoArrays(result)).to.deep.equal([[0], [1]]);
const result = aggregator.getPromiseForAggregation();
aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
expect(await WSIntoArrays(await result)).to.deep.equal([[0], [1]]);
});

it("can wait for an relative number of contributions", async () => {
Expand All @@ -90,8 +98,9 @@ describe("mean aggregator", () => {
aggregator.setNodes(Set.of(id1, id2)); // register two clients

// should aggregate with only 50% of the contribution (1 contribution)
const result = await aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
expect(await WSIntoArrays(result)).to.deep.equal([[0], [1]]);
const result = aggregator.getPromiseForAggregation();
aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
expect(await WSIntoArrays(await result)).to.deep.equal([[0], [1]]);
});

it("doesn't aggregate when not enough participants", async () => {
Expand All @@ -100,12 +109,14 @@ describe("mean aggregator", () => {
const [id1, id2] = ["client 1", "client 2"]
aggregator.setNodes(Set.of(id1));

const result1 = aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
const result1 = aggregator.getPromiseForAggregation();
aggregator.add(id1, WeightsContainer.of([0], [1]), 0);
// Make sure that the aggregation isn't triggered
expect(aggregator.round).equals(0)

aggregator.registerNode(id2);
const result2 = aggregator.add(id2, WeightsContainer.of([2], [3]), 0);
const result2 = aggregator.getPromiseForAggregation();
aggregator.add(id2, WeightsContainer.of([2], [3]), 0);
expect((await result1).equals(await result2)).to.be.true
expect(aggregator.round).equals(1)
});
Expand Down
11 changes: 5 additions & 6 deletions discojs/src/client/decentralized/decentralized_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export class DecentralizedClient extends Client {
this.server.send({ type: type.JoinRound })
// Store the promise for the current round's aggregation result.
// We will await for it to resolve at the end of the round when exchanging weight updates.
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve))
this.aggregationResult = this.aggregator.getPromiseForAggregation()
this.saveAndEmit("local training")
return Promise.resolve()
}
Expand Down Expand Up @@ -223,12 +223,11 @@ export class DecentralizedClient extends Client {
else {
debug(`[${shortenId(this.ownId)}] received payload from peer ${shortenId(peerId)}` +
` for round (%d, %d)`, message.aggregationRound, message.communicationRound);
// Make sure to not await this promise in order to not miss subsequent messages
void this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound)
.then(() =>
this.aggregator.once("aggregation", () =>
debug(`[${shortenId(this.ownId)}] aggregated the model` +
` for round (%d, %d)`, message.aggregationRound, message.communicationRound)
)
this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound)
}
} catch (e) {
if (this.isDisconnected) return
Expand Down Expand Up @@ -257,7 +256,7 @@ export class DecentralizedClient extends Client {
payloads.forEach(async (payload, id) => {
// add our own contribution to the aggregator
if (id === this.ownId) {
void this.aggregator.add(this.ownId, payload, communicationRound)
this.aggregator.add(this.ownId, payload, this.aggregator.round, communicationRound)
return
}
// Send our payload to each peer
Expand Down Expand Up @@ -294,7 +293,7 @@ export class DecentralizedClient extends Client {
// There is at least one communication round remaining
if (communicationRound < this.aggregator.communicationRounds - 1) {
// Reuse the aggregation result
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve))
this.aggregationResult = this.aggregator.getPromiseForAggregation()
}
}
return await this.aggregationResult
Expand Down
25 changes: 13 additions & 12 deletions server/src/controllers/federated_controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,19 +101,20 @@ export class FederatedController extends TrainingController {
if (this.#aggregator.isValidContribution(clientId, round)) {
const weights = serialization.weights.decode(payload)

// Send the aggregated weight to the client when enough contributions are received
// Create a callback to send the aggregated weight to the client
// when enough contributions are received
this.#aggregator.once('aggregation', async (weightUpdate) => {
debug("Sending global weights for round %o to client [%s]", this.#aggregator.round, shortId)
const msg: FederatedMessages.ReceiveServerPayload = {
type: MessageTypes.ReceiveServerPayload,
round: this.#aggregator.round, // send the current round number after aggregation
payload: await serialization.weights.encode(weightUpdate),
nbOfParticipants: this.connections.size
}
ws.send(msgpack.encode(msg))
})
// Add the contribution
this.#aggregator.add(clientId, weights, round)
.then(async (weightUpdate) => {
debug("Sending global weights for round %o to client [%s]", this.#aggregator.round, shortId)
const msg: FederatedMessages.ReceiveServerPayload = {
type: MessageTypes.ReceiveServerPayload,
round: this.#aggregator.round, // send the current round number after aggregation
payload: await serialization.weights.encode(weightUpdate),
nbOfParticipants: this.connections.size
}
ws.send(msgpack.encode(msg))
})
.catch((e) => debug("while waiting for weights: %o", e))
debug(`Successfully added contribution from client [%s] for round ${round}`, shortId)
} else {
// If the client sent an invalid or outdated contribution
Expand Down

0 comments on commit 17ed69d

Please sign in to comment.