Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GDHF demo #840

Merged
merged 14 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ interface BenchmarkArguments {
roundDuration: number
batchSize: number
save: boolean
host: URL
}

type BenchmarkUnsafeArguments = Omit<BenchmarkArguments, 'provider'> & {
Expand All @@ -22,12 +23,19 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'

const unsafeArgs = parse<BenchmarkUnsafeArguments>(
{
task: { type: String, alias: 't', description: 'Task: titanic, simple_face, cifar10 or lus_covid', defaultValue: 'simple_face' },
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 1 },
task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' },
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 },
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
host: {
type: (raw: string) => new URL(raw),
typeLabel: "URL",
description: "Host to connect to",
defaultValue: new URL("http://localhost:8080"),
},

help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' }
},
{
Expand All @@ -42,6 +50,7 @@ const supportedTasks = Map(
defaultTasks.lusCovid,
defaultTasks.simpleFace,
defaultTasks.titanic,
defaultTasks.tinderDog,
).map((t) => [t.getTask().id, t]),
);

Expand Down Expand Up @@ -69,4 +78,4 @@ export const args: BenchmarkArguments = {
},
getModel: () => provider.getModel(),
},
};
};
17 changes: 5 additions & 12 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import type {
TaskProvider,
} from "@epfml/discojs";
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'
import { Server } from 'server'

import { getTaskData } from './data.js'
import { args } from './args.js'
Expand Down Expand Up @@ -49,23 +48,17 @@ async function main<D extends DataType>(
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })

const [server, url] = await new Server().serve(undefined, provider)

const data = await getTaskData(task)

const dataSplits = await Promise.all(
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i))
)
const logs = await Promise.all(
Range(0, numberOfUsers).map(async (_) => await runUser(task, url, data)).toArray()
dataSplits.map(async data => await runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>))
)

if (args.save) {
const fileName = `${task.id}_${numberOfUsers}users.csv`;
await fs.writeFile(fileName, JSON.stringify(logs, null, 2));
}
console.log('Shutting down the server...')
await new Promise((resolve, reject) => {
server.once('close', resolve)
server.close(reject)
})
}

main(args.provider, args.numberOfUsers).catch(console.error)
main(args.provider, args.numberOfUsers).catch(console.error)
39 changes: 32 additions & 7 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import path from "node:path";

import { Dataset, processing } from "@epfml/discojs";
import type {
Dataset,
DataFormat,
DataType,
Image,
Task,
} from "@epfml/discojs";
import { loadCSV, loadImagesInDir } from "@epfml/discojs-node";
import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
import { Repeat } from "immutable";

async function loadSimpleFaceData(): Promise<Dataset<DataFormat.Raw["image"]>> {
Expand Down Expand Up @@ -36,10 +35,34 @@ async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
return positive.chain(negative);
}

function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
const folder = path.join("..", "datasets", "tinder_dog", `${split + 1}`);
return loadCSV(path.join(folder, "labels.csv"))
.map(
(row) =>
[
processing.extractColumn(row, "filename"),
processing.extractColumn(row, "label"),
] as const,
)
.map(async ([filename, label]) => {
try {
const image = await Promise.any(
["png", "jpg", "jpeg"].map((ext) =>
loadImage(path.join(folder, `${filename}.${ext}`)),
),
);
return [image, label];
} catch {
throw Error(`${filename} not found in ${folder}`);
}
});
}

export async function getTaskData<D extends DataType>(
task: Task<D>,
taskID: Task<D>['id'], userIdx: number
): Promise<Dataset<DataFormat.Raw[D]>> {
switch (task.id) {
switch (taskID) {
case "simple_face":
return (await loadSimpleFaceData()) as Dataset<DataFormat.Raw[D]>;
case "titanic":
Expand All @@ -52,7 +75,9 @@ export async function getTaskData<D extends DataType>(
).zip(Repeat("cat")) as Dataset<DataFormat.Raw[D]>;
case "lus_covid":
return (await loadLusCovidData()) as Dataset<DataFormat.Raw[D]>;
case "tinder_dog":
return loadTinderDogData(userIdx) as Dataset<DataFormat.Raw[D]>;
default:
throw new Error(`Data loader for ${task.id} not implemented.`);
throw new Error(`Data loader for ${taskID} not implemented.`);
}
}
}
3 changes: 3 additions & 0 deletions datasets/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@

# LUS Covid
/lus_covid/

# GDHF demo
/tinder_dog/
JulienVig marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 5 additions & 0 deletions datasets/populate
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ rm archive.zip DeAI-testimages
mkdir -p wikitext
curl 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz' |
tar --extract --gzip --strip-components=1 -C wikitext

# tinder_dog
curl 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip' > tinder_dog.zip
unzip -u tinder_dog.zip
rm tinder_dog.zip
2 changes: 2 additions & 0 deletions discojs/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
url.pathname += `tasks/${this.task.id}/model.json`

const response = await fetch(url);
if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`);

const encoded = new Uint8Array(await response.arrayBuffer())
return await serialization.model.decode(encoded)
}
Expand Down
3 changes: 1 addition & 2 deletions discojs/src/client/federated/federated_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import { Client, shortenId } from "../client.js";
import { type, type ClientConnected } from "../messages.js";
import {
waitMessage,
waitMessageWithTimeout,
WebSocketServer,
} from "../event_connection.js";
import * as messages from "./messages.js";
Expand Down Expand Up @@ -75,7 +74,7 @@ export class FederatedClient extends Client {
const {
id, waitForMoreParticipants, payload,
round, nbOfParticipants
} = await waitMessageWithTimeout(this.server, type.NewFederatedNodeInfo);
} = await waitMessage(this.server, type.NewFederatedNodeInfo);

// This should come right after receiving the message to make sure
// we don't miss a subsequent message from the server
Expand Down
1 change: 1 addition & 0 deletions discojs/src/default_tasks/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export { mnist } from './mnist.js'
export { simpleFace } from './simple_face.js'
export { titanic } from './titanic.js'
export { wikitext } from './wikitext.js'
export { tinderDog } from './tinder_dog.js'
84 changes: 84 additions & 0 deletions discojs/src/default_tasks/tinder_dog.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import * as tf from '@tensorflow/tfjs'

import type { Model, Task, TaskProvider } from '../index.js'
import { models } from '../index.js'

export const tinderDog: TaskProvider<'image'> = {
getTask (): Task<'image'> {
return {
id: 'tinder_dog',
displayInformation: {
taskTitle: 'GDHF 2024 | TinderDog',
summary: {
preview: 'Which dog is the cutest....or not?',
overview: "Binary classification model for dog cuteness."
},
model: 'The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 64x64 pixels and normalizes values between 0 and 1',
dataFormatInformation: 'Accepted image formats are .png .jpg and .jpeg.',
dataExampleText: '',
dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog_preview.png',
sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip',
sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, pick one of the data splits (the folder 0 for example) and use the CSV option below to select the file named "labels.csv". You can now connect the images located in the same folder.'
},
trainingInformation: {
epochs: 10,
roundDuration: 2,
validationSplit: 0, // nicer plot for GDHF demo
batchSize: 10,
dataType: 'image',
IMAGE_H: 64,
IMAGE_W: 64,
LABEL_LIST: ['Cute dogs', 'Less cute dogs'],
scheme: 'federated',
aggregationStrategy: 'mean',
minNbOfParticipants: 3,
tensorBackend: 'tfjs'
}
}
},


async getModel(): Promise<Model<'image'>> {
const seed = 42 // set a seed to ensure reproducibility during GDHF demo
const imageHeight = this.getTask().trainingInformation.IMAGE_H
const imageWidth = this.getTask().trainingInformation.IMAGE_W
const imageChannels = 3

const model = tf.sequential()

model.add(
tf.layers.conv2d({
inputShape: [imageHeight, imageWidth, imageChannels],
kernelSize: 5,
filters: 8,
activation: 'relu',
kernelInitializer: tf.initializers.heNormal({ seed })
})
)
model.add(tf.layers.conv2d({
kernelSize: 5, filters: 16, activation: 'relu',
kernelInitializer: tf.initializers.heNormal({ seed })
}))
model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }))
model.add(tf.layers.dropout({ rate: 0.25, seed }))

model.add(tf.layers.flatten())
model.add(tf.layers.dense({
units: 32, activation: 'relu',
kernelInitializer: tf.initializers.heNormal({ seed })
}))
model.add(tf.layers.dropout({rate:0.25, seed}))
model.add(tf.layers.dense({
units: 2, activation: 'softmax',
kernelInitializer: tf.initializers.heNormal({ seed })
}))

model.compile({
optimizer: tf.train.adam(0.0005),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
})

return Promise.resolve(new models.TFJS('image', model))
}
}
4 changes: 3 additions & 1 deletion discojs/src/task/task_handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,22 @@ export async function pushTask<D extends DataType>(
task: Task<D>,
model: Model<D>,
): Promise<void> {
await fetch(urlToTasks(base), {
const response = await fetch(urlToTasks(base), {
method: "POST",
body: JSON.stringify({
task,
model: await serialization.model.encode(model),
weights: await serialization.weights.encode(model.weights),
}),
});
if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`);
}

export async function fetchTasks(
base: URL,
): Promise<Map<TaskID, Task<DataType>>> {
const response = await fetch(urlToTasks(base));
if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`);
const tasks: unknown = await response.json();

if (!Array.isArray(tasks)) {
Expand Down
13 changes: 8 additions & 5 deletions discojs/src/training/disco.ts
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,21 @@ export class Disco<D extends DataType> extends EventEmitter<{
): Promise<
[
Dataset<Batched<DataFormat.ModelEncoded[D]>>,
Dataset<Batched<DataFormat.ModelEncoded[D]>>,
Dataset<Batched<DataFormat.ModelEncoded[D]>> | undefined,
]
> {
const { batchSize, validationSplit } = this.#task.trainingInformation;

const preprocessed = await processing.preprocess(this.#task, dataset);
let preprocessed = await processing.preprocess(this.#task, dataset);

const [training, validation] = (
preprocessed = (
this.#preprocessOnce
? new Dataset(await arrayFromAsync(preprocessed))
: preprocessed
).split(validationSplit);
)
if (validationSplit === 0) return [preprocessed.batch(batchSize).cached(), undefined];

const [training, validation] = preprocessed.split(validationSplit);

return [
training.batch(batchSize).cached(),
Expand All @@ -230,4 +233,4 @@ async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
const ret: T[] = [];
for await (const e of iter) ret.push(e);
return ret;
}
}
11 changes: 9 additions & 2 deletions server/src/controllers/federated_controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ export class FederatedController<
*/
#latestGlobalWeights: serialization.Encoded;

constructor(task: Task<D>, initialWeights: serialization.Encoded) {
constructor(task: Task<D>, private readonly initialWeights: serialization.Encoded) {
super(task)
this.#latestGlobalWeights = initialWeights
this.#latestGlobalWeights = this.initialWeights

// Save the latest weight updates to be able to send it to new or outdated clients
this.#aggregator.on('aggregation', async (weightUpdate) => {
Expand Down Expand Up @@ -145,6 +145,13 @@ export class FederatedController<
this.#aggregator.removeNode(clientId)
debug("client [%s] left", shortId)

// Reset the training session when all participants left
if (this.connections.size === 0) {
debug("All participants left. Resetting the training session")
this.#aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative')
this.#latestGlobalWeights = this.initialWeights
}

// Check if we dropped below the minimum number of participant required
// or if we are already waiting for new participants to join
if (this.connections.size >= minNbOfParticipants ||
Expand Down
6 changes: 3 additions & 3 deletions webapp/src/components/containers/ImageCard.vue
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
<template>
<div
class="grid grid-cols-1 w-full bg-white dark:bg-slate-800 aspect-square rounded-xl drop-shadow-md hover:drop-shadow-xl transition duration-500 hover:scale-105 opacity-70 hover:opacity-100 shadow hover:shadow-lg"
class="grid grid-cols-1 w-full bg-white dark:bg-slate-800 aspect-square rounded-xl drop-shadow-md hover:drop-shadow-xl transition duration-500 hover:scale-125 shadow hover:shadow-lg hover:z-50"
>
<div class="grid grid-cols-1 gap-1 text-center content-center p-2 h-16">
<div class="grid grid-cols-1 gap-1 text-center content-center pt-2 h-8">
<slot name="title" />
<div class="text-sm">
<slot name="subtitle" />
Expand Down Expand Up @@ -35,4 +35,4 @@ function draw() {
if (context === null) throw new Error("canvas doesn't support 2D context");
context.putImageData(props.image, 0, 0);
}
</script>
</script>
Loading
Loading