diff --git a/cli/src/args.ts b/cli/src/args.ts index 74c4ed5c7..ab5e0d92c 100644 --- a/cli/src/args.ts +++ b/cli/src/args.ts @@ -11,6 +11,7 @@ interface BenchmarkArguments { roundDuration: number batchSize: number save: boolean + host: URL } type BenchmarkUnsafeArguments = Omit & { @@ -22,12 +23,19 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs' const unsafeArgs = parse( { - task: { type: String, alias: 't', description: 'Task: titanic, simple_face, cifar10 or lus_covid', defaultValue: 'simple_face' }, - numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 1 }, + task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' }, + numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 }, epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 }, roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 }, batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 }, save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false }, + host: { + type: (raw: string) => new URL(raw), + typeLabel: "URL", + description: "Host to connect to", + defaultValue: new URL("http://localhost:8080"), + }, + help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' } }, { @@ -42,6 +50,7 @@ const supportedTasks = Map( defaultTasks.lusCovid, defaultTasks.simpleFace, defaultTasks.titanic, + defaultTasks.tinderDog, ).map((t) => [t.getTask().id, t]), ); @@ -69,4 +78,4 @@ export const args: BenchmarkArguments = { }, getModel: () => provider.getModel(), }, -}; +}; \ No newline at end of file diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 54ff0c3ef..c56c9280b 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -13,7 +13,6 @@ import type { TaskProvider, } from "@epfml/discojs"; import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs' -import { Server } from 'server' import { getTaskData } from './data.js' import { args } from './args.js' @@ -49,23 +48,17 @@ async function main( console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`) console.log({ args }) - const [server, url] = await new Server().serve(undefined, provider) - - const data = await getTaskData(task) - + const dataSplits = await Promise.all( + Range(0, numberOfUsers).map(async i => getTaskData(task.id, i)) + ) const logs = await Promise.all( - Range(0, numberOfUsers).map(async (_) => await runUser(task, url, data)).toArray() + dataSplits.map(async data => await runUser(task, args.host, data as Dataset)) ) if (args.save) { const fileName = `${task.id}_${numberOfUsers}users.csv`; await fs.writeFile(fileName, JSON.stringify(logs, null, 2)); } - console.log('Shutting down the server...') - await new Promise((resolve, reject) => { - server.once('close', resolve) - server.close(reject) - }) } -main(args.provider, args.numberOfUsers).catch(console.error) +main(args.provider, args.numberOfUsers).catch(console.error) \ No newline at end of file diff --git a/cli/src/data.ts b/cli/src/data.ts index 895a35bf7..c3077c11b 100644 --- a/cli/src/data.ts +++ b/cli/src/data.ts @@ -1,13 +1,12 @@ import path from "node:path"; - +import { Dataset, processing } from "@epfml/discojs"; import type { - Dataset, DataFormat, DataType, Image, Task, } from "@epfml/discojs"; -import { loadCSV, loadImagesInDir } from "@epfml/discojs-node"; +import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node"; import { Repeat } from "immutable"; async function loadSimpleFaceData(): Promise> { @@ -36,10 +35,34 @@ async function loadLusCovidData(): Promise> { return positive.chain(negative); } +function loadTinderDogData(split: number): Dataset { + const folder = path.join("..", "datasets", "tinder_dog", `${split + 1}`); + return loadCSV(path.join(folder, "labels.csv")) + .map( + (row) => + [ + processing.extractColumn(row, "filename"), + processing.extractColumn(row, "label"), + ] as const, + ) + .map(async ([filename, label]) => { + try { + const image = await Promise.any( + ["png", "jpg", "jpeg"].map((ext) => + loadImage(path.join(folder, `${filename}.${ext}`)), + ), + ); + return [image, label]; + } catch { + throw Error(`${filename} not found in ${folder}`); + } + }); +} + export async function getTaskData( - task: Task, + taskID: Task['id'], userIdx: number ): Promise> { - switch (task.id) { + switch (taskID) { case "simple_face": return (await loadSimpleFaceData()) as Dataset; case "titanic": @@ -52,7 +75,9 @@ export async function getTaskData( ).zip(Repeat("cat")) as Dataset; case "lus_covid": return (await loadLusCovidData()) as Dataset; + case "tinder_dog": + return loadTinderDogData(userIdx) as Dataset; default: - throw new Error(`Data loader for ${task.id} not implemented.`); + throw new Error(`Data loader for ${taskID} not implemented.`); } -} +} \ No newline at end of file diff --git a/datasets/.gitignore b/datasets/.gitignore index d1d80c705..f73eda468 100644 --- a/datasets/.gitignore +++ b/datasets/.gitignore @@ -17,3 +17,6 @@ # LUS Covid /lus_covid/ + +# GDHF demo +/tinder_dog/ diff --git a/datasets/populate b/datasets/populate index 3f86a82fa..105ea0710 100755 --- a/datasets/populate +++ b/datasets/populate @@ -19,3 +19,8 @@ rm archive.zip DeAI-testimages mkdir -p wikitext curl 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz' | tar --extract --gzip --strip-components=1 -C wikitext + +# tinder_dog +curl 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip' > tinder_dog.zip +unzip -u tinder_dog.zip +rm tinder_dog.zip \ No newline at end of file diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 9f05febc7..795f9346e 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -178,6 +178,8 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{ url.pathname += `tasks/${this.task.id}/model.json` const response = await fetch(url); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); + const encoded = new Uint8Array(await response.arrayBuffer()) return await serialization.model.decode(encoded) } diff --git a/discojs/src/client/federated/federated_client.ts b/discojs/src/client/federated/federated_client.ts index e6c89cc79..fc172d7b8 100644 --- a/discojs/src/client/federated/federated_client.ts +++ b/discojs/src/client/federated/federated_client.ts @@ -6,7 +6,6 @@ import { Client, shortenId } from "../client.js"; import { type, type ClientConnected } from "../messages.js"; import { waitMessage, - waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js"; import * as messages from "./messages.js"; @@ -75,7 +74,7 @@ export class FederatedClient extends Client { const { id, waitForMoreParticipants, payload, round, nbOfParticipants - } = await waitMessageWithTimeout(this.server, type.NewFederatedNodeInfo); + } = await waitMessage(this.server, type.NewFederatedNodeInfo); // This should come right after receiving the message to make sure // we don't miss a subsequent message from the server diff --git a/discojs/src/default_tasks/index.ts b/discojs/src/default_tasks/index.ts index 7ee583f1f..43adf0d3c 100644 --- a/discojs/src/default_tasks/index.ts +++ b/discojs/src/default_tasks/index.ts @@ -4,3 +4,4 @@ export { mnist } from './mnist.js' export { simpleFace } from './simple_face.js' export { titanic } from './titanic.js' export { wikitext } from './wikitext.js' +export { tinderDog } from './tinder_dog.js' \ No newline at end of file diff --git a/discojs/src/default_tasks/tinder_dog.ts b/discojs/src/default_tasks/tinder_dog.ts new file mode 100644 index 000000000..a19bf5f8b --- /dev/null +++ b/discojs/src/default_tasks/tinder_dog.ts @@ -0,0 +1,84 @@ +import * as tf from '@tensorflow/tfjs' + +import type { Model, Task, TaskProvider } from '../index.js' +import { models } from '../index.js' + +export const tinderDog: TaskProvider<'image'> = { + getTask (): Task<'image'> { + return { + id: 'tinder_dog', + displayInformation: { + taskTitle: 'GDHF 2024 | TinderDog', + summary: { + preview: 'Which dog is the cutest....or not?', + overview: "Binary classification model for dog cuteness." + }, + model: 'The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 64x64 pixels and normalizes values between 0 and 1', + dataFormatInformation: 'Accepted image formats are .png .jpg and .jpeg.', + dataExampleText: '', + dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog_preview.png', + sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip', + sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, pick one of the data splits (the folder 0 for example) and use the CSV option below to select the file named "labels.csv". You can now connect the images located in the same folder.' + }, + trainingInformation: { + epochs: 10, + roundDuration: 2, + validationSplit: 0, // nicer plot for GDHF demo + batchSize: 10, + dataType: 'image', + IMAGE_H: 64, + IMAGE_W: 64, + LABEL_LIST: ['Cute dogs', 'Less cute dogs'], + scheme: 'federated', + aggregationStrategy: 'mean', + minNbOfParticipants: 3, + tensorBackend: 'tfjs' + } + } + }, + + + async getModel(): Promise> { + const seed = 42 // set a seed to ensure reproducibility during GDHF demo + const imageHeight = this.getTask().trainingInformation.IMAGE_H + const imageWidth = this.getTask().trainingInformation.IMAGE_W + const imageChannels = 3 + + const model = tf.sequential() + + model.add( + tf.layers.conv2d({ + inputShape: [imageHeight, imageWidth, imageChannels], + kernelSize: 5, + filters: 8, + activation: 'relu', + kernelInitializer: tf.initializers.heNormal({ seed }) + }) + ) + model.add(tf.layers.conv2d({ + kernelSize: 5, filters: 16, activation: 'relu', + kernelInitializer: tf.initializers.heNormal({ seed }) + })) + model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })) + model.add(tf.layers.dropout({ rate: 0.25, seed })) + + model.add(tf.layers.flatten()) + model.add(tf.layers.dense({ + units: 32, activation: 'relu', + kernelInitializer: tf.initializers.heNormal({ seed }) + })) + model.add(tf.layers.dropout({rate:0.25, seed})) + model.add(tf.layers.dense({ + units: 2, activation: 'softmax', + kernelInitializer: tf.initializers.heNormal({ seed }) + })) + + model.compile({ + optimizer: tf.train.adam(0.0005), + loss: 'categoricalCrossentropy', + metrics: ['accuracy'] + }) + + return Promise.resolve(new models.TFJS('image', model)) + } +} \ No newline at end of file diff --git a/discojs/src/task/task_handler.ts b/discojs/src/task/task_handler.ts index 7b38a6928..d2b692bb4 100644 --- a/discojs/src/task/task_handler.ts +++ b/discojs/src/task/task_handler.ts @@ -20,7 +20,7 @@ export async function pushTask( task: Task, model: Model, ): Promise { - await fetch(urlToTasks(base), { + const response = await fetch(urlToTasks(base), { method: "POST", body: JSON.stringify({ task, @@ -28,12 +28,14 @@ export async function pushTask( weights: await serialization.weights.encode(model.weights), }), }); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); } export async function fetchTasks( base: URL, ): Promise>> { const response = await fetch(urlToTasks(base)); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); const tasks: unknown = await response.json(); if (!Array.isArray(tasks)) { diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 2ef90728e..07d37c68c 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -205,18 +205,21 @@ export class Disco extends EventEmitter<{ ): Promise< [ Dataset>, - Dataset>, + Dataset> | undefined, ] > { const { batchSize, validationSplit } = this.#task.trainingInformation; - const preprocessed = await processing.preprocess(this.#task, dataset); + let preprocessed = await processing.preprocess(this.#task, dataset); - const [training, validation] = ( + preprocessed = ( this.#preprocessOnce ? new Dataset(await arrayFromAsync(preprocessed)) : preprocessed - ).split(validationSplit); + ) + if (validationSplit === 0) return [preprocessed.batch(batchSize).cached(), undefined]; + + const [training, validation] = preprocessed.split(validationSplit); return [ training.batch(batchSize).cached(), @@ -230,4 +233,4 @@ async function arrayFromAsync(iter: AsyncIterable): Promise { const ret: T[] = []; for await (const e of iter) ret.push(e); return ret; -} +} \ No newline at end of file diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index c58a05328..0e5f92d71 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -32,9 +32,9 @@ export class FederatedController< */ #latestGlobalWeights: serialization.Encoded; - constructor(task: Task, initialWeights: serialization.Encoded) { + constructor(task: Task, private readonly initialWeights: serialization.Encoded) { super(task) - this.#latestGlobalWeights = initialWeights + this.#latestGlobalWeights = this.initialWeights // Save the latest weight updates to be able to send it to new or outdated clients this.#aggregator.on('aggregation', async (weightUpdate) => { @@ -145,6 +145,13 @@ export class FederatedController< this.#aggregator.removeNode(clientId) debug("client [%s] left", shortId) + // Reset the training session when all participants left + if (this.connections.size === 0) { + debug("All participants left. Resetting the training session") + this.#aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') + this.#latestGlobalWeights = this.initialWeights + } + // Check if we dropped below the minimum number of participant required // or if we are already waiting for new participants to join if (this.connections.size >= minNbOfParticipants || diff --git a/webapp/src/components/containers/ImageCard.vue b/webapp/src/components/containers/ImageCard.vue index aeeee265f..623cb4070 100644 --- a/webapp/src/components/containers/ImageCard.vue +++ b/webapp/src/components/containers/ImageCard.vue @@ -1,8 +1,8 @@