Skip to content

Commit

Permalink
discojs/types: add DataFormat namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Oct 23, 2024
1 parent 7730877 commit cca6dc3
Show file tree
Hide file tree
Showing 20 changed files with 160 additions and 136 deletions.
11 changes: 9 additions & 2 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@ import "@tensorflow/tfjs-node"
import { List, Range } from 'immutable'
import fs from 'node:fs/promises'

import type { Dataset, DataType, Raw, RoundLogs, Task, TaskProvider } from '@epfml/discojs'
import type {
Dataset,
DataFormat,
DataType,
RoundLogs,
Task,
TaskProvider,
} from "@epfml/discojs";
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'
import { Server } from 'server'

Expand All @@ -21,7 +28,7 @@ async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
async function runUser<D extends DataType>(
task: Task<D>,
url: URL,
data: Dataset<Raw[D]>,
data: Dataset<DataFormat.Raw[D]>,
): Promise<List<RoundLogs>> {
const trainingScheme = task.trainingInformation.scheme
const aggregator = aggregators.getAggregator(task)
Expand Down
22 changes: 14 additions & 8 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import path from "node:path";

import type { Dataset, DataType, Image, Raw, Task } from "@epfml/discojs";
import type {
Dataset,
DataFormat,
DataType,
Image,
Task,
} from "@epfml/discojs";
import { loadCSV, loadImagesInDir } from "@epfml/discojs-node";
import { Repeat } from "immutable";

async function loadSimpleFaceData(): Promise<Dataset<Raw["image"]>> {
async function loadSimpleFaceData(): Promise<Dataset<DataFormat.Raw["image"]>> {
const folder = path.join("..", "datasets", "simple_face");

const [adults, childs]: Dataset<[Image, string]>[] = [
Expand All @@ -15,7 +21,7 @@ async function loadSimpleFaceData(): Promise<Dataset<Raw["image"]>> {
return adults.chain(childs);
}

async function loadLusCovidData(): Promise<Dataset<Raw["image"]>> {
async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
const folder = path.join("..", "datasets", "lus_covid");

const [positive, negative]: Dataset<[Image, string]>[] = [
Expand All @@ -32,20 +38,20 @@ async function loadLusCovidData(): Promise<Dataset<Raw["image"]>> {

export async function getTaskData<D extends DataType>(
task: Task<D>,
): Promise<Dataset<Raw[D]>> {
): Promise<Dataset<DataFormat.Raw[D]>> {
switch (task.id) {
case "simple_face":
return (await loadSimpleFaceData()) as Dataset<Raw[D]>;
return (await loadSimpleFaceData()) as Dataset<DataFormat.Raw[D]>;
case "titanic":
return loadCSV(
path.join("..", "datasets", "titanic_train.csv"),
) as Dataset<Raw[D]>;
) as Dataset<DataFormat.Raw[D]>;
case "cifar10":
return (
await loadImagesInDir(path.join("..", "datasets", "CIFAR10"))
).zip(Repeat("cat")) as Dataset<Raw[D]>;
).zip(Repeat("cat")) as Dataset<DataFormat.Raw[D]>;
case "lus_covid":
return (await loadLusCovidData()) as Dataset<Raw[D]>;
return (await loadLusCovidData()) as Dataset<DataFormat.Raw[D]>;
default:
throw new Error(`Data loader for ${task.id} not implemented.`);
}
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ export * as async_iterator from "./utils/async_iterator.js"
export { EventEmitter } from "./utils/event_emitter.js"

export * from "./dataset/index.js";
export * from "./types.js";
export * from "./types/index.js";

export * as processing from "./processing/index.js";
4 changes: 2 additions & 2 deletions discojs/src/models/gpt/gpt.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { expect } from "chai";
import "@tensorflow/tfjs-node"; // speed up
import { AutoTokenizer } from "@xenova/transformers";

import { Dataset, ModelEncoded } from "../../index.js";
import { Dataset, DataFormat } from "../../index.js";

import { GPT } from "./index.js";
import { List, Repeat } from "immutable";
Expand All @@ -17,7 +17,7 @@ describe("gpt-tfjs", function () {
(tokenizer(data, { return_tensor: false }) as { input_ids: number[] })
.input_ids,
);
const dataset = new Dataset<ModelEncoded["text"]>(
const dataset = new Dataset<DataFormat.ModelEncoded["text"]>(
Repeat([dataTokens.pop(), dataTokens.last()]),
).batch(64);

Expand Down
22 changes: 12 additions & 10 deletions discojs/src/models/gpt/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import createDebug from "debug";
import { List, Range } from "immutable";
import * as tf from "@tensorflow/tfjs";

import type { Batched, Dataset, ModelEncoded } from "../../index.js";
import type { Batched, Dataset, DataFormat } from "../../index.js";
import { WeightsContainer } from "../../index.js";

import { BatchLogs, Model, EpochLogs } from "../index.js";
Expand Down Expand Up @@ -57,8 +57,8 @@ export class GPT extends Model<"text"> {
* @param tracker
*/
override async *train(
trainingDataset: Dataset<Batched<ModelEncoded["text"]>>,
validationDataset?: Dataset<Batched<ModelEncoded["text"]>>,
trainingDataset: Dataset<Batched<DataFormat.ModelEncoded["text"]>>,
validationDataset?: Dataset<Batched<DataFormat.ModelEncoded["text"]>>,
): AsyncGenerator<BatchLogs, EpochLogs> {
let batchesLogs = List<BatchLogs>();

Expand All @@ -77,7 +77,9 @@ export class GPT extends Model<"text"> {
return new EpochLogs(batchesLogs, validation);
}

async #runBatch(batch: Batched<ModelEncoded["text"]>): Promise<BatchLogs> {
async #runBatch(
batch: Batched<DataFormat.ModelEncoded["text"]>,
): Promise<BatchLogs> {
const tfBatch = this.#batchToTF(batch);

let logs: tf.Logs | undefined;
Expand Down Expand Up @@ -105,7 +107,7 @@ export class GPT extends Model<"text"> {
}

async #evaluate(
dataset: Dataset<Batched<ModelEncoded["text"]>>,
dataset: Dataset<Batched<DataFormat.ModelEncoded["text"]>>,
): Promise<Record<"accuracy" | "loss", number>> {
const evaluation = await evaluate(
this.model,
Expand All @@ -119,7 +121,7 @@ export class GPT extends Model<"text"> {
};
}

#batchToTF(batch: Batched<ModelEncoded["text"]>): {
#batchToTF(batch: Batched<DataFormat.ModelEncoded["text"]>): {
xs: tf.Tensor2D;
ys: tf.Tensor3D;
} {
Expand All @@ -138,9 +140,9 @@ export class GPT extends Model<"text"> {
}

override async predict(
batch: Batched<ModelEncoded["text"][0]>,
batch: Batched<DataFormat.ModelEncoded["text"][0]>,
options?: Partial<PredictConfig>,
): Promise<Batched<ModelEncoded["text"][1]>> {
): Promise<Batched<DataFormat.ModelEncoded["text"][1]>> {
const config = {
temperature: 1.0,
doSample: false,
Expand All @@ -155,9 +157,9 @@ export class GPT extends Model<"text"> {
}

async #predictSingle(
tokens: ModelEncoded["text"][0],
tokens: DataFormat.ModelEncoded["text"][0],
config: PredictConfig,
): Promise<ModelEncoded["text"][1]> {
): Promise<DataFormat.ModelEncoded["text"][1]> {
// slice input tokens if longer than context length
tokens = tokens.slice(-this.#blockSize);

Expand Down
10 changes: 5 additions & 5 deletions discojs/src/models/model.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import type {
Batched,
Dataset,
DataFormat,
DataType,
ModelEncoded,
WeightsContainer,
} from "../index.js";

Expand All @@ -29,15 +29,15 @@ export abstract class Model<D extends DataType> implements Disposable {
* @yields on every epoch, training can be stopped by `return`ing or `throw`ing it
*/
abstract train(
trainingDataset: Dataset<Batched<ModelEncoded[D]>>,
validationDataset?: Dataset<Batched<ModelEncoded[D]>>,
trainingDataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>,
validationDataset?: Dataset<Batched<DataFormat.ModelEncoded[D]>>,
): AsyncGenerator<BatchLogs, EpochLogs>;

/** Predict likely values */
// TODO extract into a separate TrainedModel?
abstract predict(
batch: Batched<ModelEncoded[D][0]>,
): Promise<Batched<ModelEncoded[D][1]>>;
batch: Batched<DataFormat.ModelEncoded[D][0]>,
): Promise<Batched<DataFormat.ModelEncoded[D][1]>>;

/**
* This method is automatically called to cleanup the memory occupied by the model
Expand Down
30 changes: 17 additions & 13 deletions discojs/src/models/tfjs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import * as tf from '@tensorflow/tfjs'
import {
Batched,
Dataset,
DataFormat,
DataType,
ModelEncoded,
WeightsContainer,
} from "../index.js";

Expand Down Expand Up @@ -40,8 +40,8 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
}

override async *train(
trainingDataset: Dataset<Batched<ModelEncoded[D]>>,
validationDataset?: Dataset<Batched<ModelEncoded[D]>>,
trainingDataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>,
validationDataset?: Dataset<Batched<DataFormat.ModelEncoded[D]>>,
): AsyncGenerator<BatchLogs, EpochLogs> {
let batchesLogs = List<BatchLogs>();
for await (const [batch, batchNumber] of trainingDataset.zip(Range())) {
Expand All @@ -60,7 +60,7 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
}

async #runBatch(
batch: Batched<ModelEncoded[D]>,
batch: Batched<DataFormat.ModelEncoded[D]>,
): Promise<Omit<BatchLogs, "batch">> {
const { xs, ys } = this.#batchToTF(batch);

Expand Down Expand Up @@ -88,7 +88,7 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
}

async #evaluate(
dataset: Dataset<Batched<ModelEncoded[D]>>,
dataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>,
): Promise<Record<"accuracy" | "loss", number>> {
const evaluation = await this.model.evaluateDataset(
intoTFDataset(dataset.map((batch) => this.#batchToTF(batch))),
Expand Down Expand Up @@ -116,8 +116,8 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
}

override async predict(
batch: Batched<ModelEncoded[D][0]>,
): Promise<Batched<ModelEncoded[D][1]>> {
batch: Batched<DataFormat.ModelEncoded[D][0]>,
): Promise<Batched<DataFormat.ModelEncoded[D][1]>> {
async function cleanupPredicted(y: tf.Tensor1D): Promise<number> {
if (y.shape[0] === 1) {
// Binary classification
Expand Down Expand Up @@ -215,7 +215,9 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
return this.model
}

#batchToTF(batch: Batched<ModelEncoded[D]>): Record<"xs" | "ys", tf.Tensor> {
#batchToTF(
batch: Batched<DataFormat.ModelEncoded[D]>,
): Record<"xs" | "ys", tf.Tensor> {
const outputSize = tf.util.sizeFromShape(
this.model.outputShape.map((dim) => {
if (Array.isArray(dim))
Expand All @@ -227,7 +229,7 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
switch (this.datatype) {
case "image": {
// cast as typescript doesn't reduce generic type
const b = batch as Batched<ModelEncoded["image"]>;
const b = batch as Batched<DataFormat.ModelEncoded["image"]>;

return tf.tidy(() => ({
xs: tf.stack(
Expand All @@ -250,7 +252,7 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
}
case "tabular": {
// cast as typescript doesn't reduce generic type
const b = batch as Batched<ModelEncoded["tabular"]>;
const b = batch as Batched<DataFormat.ModelEncoded["tabular"]>;

return tf.tidy(() => ({
xs: tf.stack(
Expand All @@ -265,11 +267,13 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
throw new Error("should never happen");
}

#batchWithoutLabelToTF(batch: Batched<ModelEncoded[D][0]>): tf.Tensor {
#batchWithoutLabelToTF(
batch: Batched<DataFormat.ModelEncoded[D][0]>,
): tf.Tensor {
switch (this.datatype) {
case "image": {
// cast as typescript doesn't reduce generic type
const b = batch as Batched<ModelEncoded["image"][0]>;
const b = batch as Batched<DataFormat.ModelEncoded["image"][0]>;

return tf.tidy(() => tf.stack(
b
Expand All @@ -286,7 +290,7 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
}
case "tabular": {
// cast as typescript doesn't reduce generic type
const b = batch as Batched<ModelEncoded["tabular"][0]>;
const b = batch as Batched<DataFormat.ModelEncoded["tabular"][0]>;

return tf.tidy(() =>
tf.stack(
Expand Down
Loading

0 comments on commit cca6dc3

Please sign in to comment.