Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Emmanuel confusion matrix #833

Open
wants to merge 19 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
5191140
Return the predicted label in the tests
tomasoignons Nov 21, 2024
8bb5b8c
Added the confusion matrix generation
tomasoignons Nov 21, 2024
ee6d844
Added dark mode support and labels for the confusion matrix
tomasoignons Nov 24, 2024
696c3f8
remove logs and unused variables
tomasoignons Nov 24, 2024
2debf68
merge commit
tomasoignons Dec 16, 2024
4577a0c
setting the map and set to immutable, and adding a pseudo confusion m…
tomasoignons Dec 16, 2024
2854bd8
Simplify code
tomasoignons Dec 16, 2024
6762c1a
Merge branch 'emmanuel-confusion-matrix' of github.com:epfml/disco in…
tomasoignons Dec 16, 2024
7839dfc
Simplfy code by removing a div
tomasoignons Dec 16, 2024
8b0bc1f
reverse the truth and label labels, really confusing
tomasoignons Dec 16, 2024
0fc88ee
removed the ugly way to compute labels, because we have them all in t…
tomasoignons Dec 16, 2024
dae4ba6
added matrix for tabular binary classification
tomasoignons Dec 16, 2024
30d84a4
small mistake between columns and row
tomasoignons Dec 16, 2024
78965b7
remove some spaces on the rendering of the matrix
tomasoignons Dec 16, 2024
14918ef
modify the validator type, which causes a whole refractor of the way …
tomasoignons Dec 17, 2024
7b90480
remove console.log and unused variables
tomasoignons Dec 17, 2024
817ecb5
remove a span element that doesn't do anything
tomasoignons Dec 17, 2024
53ada75
modification of the validator to make it more clear
tomasoignons Dec 17, 2024
d63fc11
modification of the timeout of the test of the validator 'can read an…
tomasoignons Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions discojs/src/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,29 @@ export class Validator<D extends DataType> {
/** infer every line of the dataset and check that it is as labelled */
async *test(
dataset: Dataset<DataFormat.Raw[D]>,
): AsyncGenerator<boolean, void> {
const results = (await processing.preprocess(this.task, dataset))
.batch(this.task.trainingInformation.batchSize)
): AsyncGenerator<{ result: boolean; predicted: DataFormat.Inferred[D]; truth : number }, void> {
const preprocessed = await processing.preprocess(this.task, dataset);
const batched = preprocessed.batch(this.task.trainingInformation.batchSize);

const initialResults = batched
.map(async (batch) =>
(await this.#model.predict(batch.map(([inputs, _]) => inputs)))
.zip(batch.map(([_, outputs]) => outputs))
.map(([inferred, truth]) => inferred === truth),
.map(([inferred, truth]) => ({ result: inferred === truth, predicted: inferred, truth })),
)
.flatten();

for await (const e of results) yield e;
const predictions = await processing.postprocess(
this.task,
initialResults.map(({ predicted }) => predicted),
);

const finalResults = initialResults.zip(predictions).map(([result, predicted]) => ({
...result,
predicted,
}));

for await (const e of finalResults) yield e;
}

/** use the model to predict every line of the dataset */
Expand Down
2 changes: 1 addition & 1 deletion server/tests/validator.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ describe("validator", () => {
}

expect(hits / size).to.be.greaterThan(0.3);
}).timeout("5s");
}).timeout("10s");

it("can read and predict randomly on titanic", async () => {
const provider = defaultTasks.titanic;
Expand Down
125 changes: 106 additions & 19 deletions webapp/src/components/testing/TestSteps.vue
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,35 @@
</div>
</div>


<div v-if="confusionMatrix && confusionMatrix.matrix && Object.keys(confusionMatrix.matrix).length > 0" class="p-4 mx-auto lg:w-1/2 h-full bg-white dark:bg-slate-950 rounded-md">
<h4 class="p-4 text-lg font-semibold text-slate-500 dark:text-slate-300">
Confusion Matrix
</h4>
<table class="min-w-full divide-y divide-slate-600 dark:divide-slate-400 text-center">
<thead>
<tr>
<th class="px-0 py-3 text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider text-center border-r-gray-600 dark:border-r-gray-400 border-r-2 diagonal-header">
Label \ Prediction
</th>
<th v-for="(label, index) in Object.keys(confusionMatrix.matrix)" :key="'header-' + index" class="text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider">
{{ label }}
</th>
</tr>
</thead>
<tbody>
<tr v-for="(rowLabel, rowIndex) in Object.keys(confusionMatrix.matrix)" :key="'row-' + rowIndex">
<td class="py-2 whitespace-nowrap text-sm font-medium text-gray-800 dark:text-gray-200 border-r-gray-600 dark:border-r-gray-400 border-r-2">
{{ rowLabel }}
</td>
<td v-for="(colLabel, colIndex) in Object.keys(confusionMatrix.matrix[rowLabel])" :key="'col-' + colIndex" class="whitespace-nowrap text-sm dark:text-gray-300 text-gray-700">
{{ confusionMatrix.matrix[rowLabel][colLabel] }}
</td>
</tr>
</tbody>
</table>
</div>

<div v-if="tested !== undefined">
<div class="mx-auto lg:w-1/2 text-center pb-8">
<CustomButton @click="saveCsv()"> download as csv </CustomButton>
Expand Down Expand Up @@ -96,7 +125,7 @@
"
:rows="
(tested as Tested['tabular']).results.map(({ input, output }) =>
input.concat(output.truth).push(output.correct.toString()),
input.concat(output.label).push(output.correct.toString()),
)
"
/>
Expand Down Expand Up @@ -129,6 +158,7 @@ import ImageCard from "@/components/containers/ImageCard.vue";
import LabeledDatasetInput from "@/components/dataset_input/LabeledDatasetInput.vue";
import TableLayout from "@/components/containers/TableLayout.vue";
import type { LabeledDataset } from "@/components/dataset_input/types.js";
import { Map } from 'immutable';

const debug = createDebug("webapp:testing:TestSteps");
const toaster = useToaster();
Expand All @@ -142,24 +172,24 @@ const props = defineProps<{
interface Tested {
image: List<{
input: { filename: string; image: ImageData };
output: { truth: string; correct: boolean };
output: { truth: number; correct: boolean; predicted : string, label : string };
}>;
tabular: {
labels: {
input: List<string>;
output: { truth: string; correct: string };
output: { truth: string; correct: string, label : string };
};
results: List<{
input: List<string>;
output: { truth: string; correct: boolean };
output: { truth: number; correct: boolean; predicted : number, label : string };
}>;
};
// TODO what to show?
text: List<{ output: { correct: boolean } }>;
}

const dataset = ref<LabeledDataset[D]>();
const generator = ref<AsyncGenerator<boolean, void>>();
const generator = ref<AsyncGenerator<{result : boolean, predicted : string | number; truth : number}, void>>();
const tested = ref<Tested[D]>();

const visitedSamples = computed<number>(() => {
Expand All @@ -177,9 +207,62 @@ const visitedSamples = computed<number>(() => {
}
}
});

const confusionMatrix = computed<{ labels: Map<number, string>; matrix: { [key: string]: { [key: string]: number } } }>(() => {
if (tested.value === undefined) {
return { labels: Map<number, string>(), matrix: {} };
}
let labels : string[] = [];
switch (props.task.trainingInformation.dataType) {
case "image" :
labels = (props.task as Task<"image">).trainingInformation.LABEL_LIST;
break;
case "tabular" :
labels = ["0", "1"]; // binary classification
break;
case "text" :
return { labels: Map<number, string>(), matrix: {} }; default: {
const _: never = props.task.trainingInformation;
throw new Error("should never happen");
}
}

// Initialize the confusion matrix
const matrix: { [key: string]: { [key: string]: number } } = {};

// Initialize the confusion matrix
labels.forEach((label) => {
matrix[label.toString()] = {};
labels.forEach((innerLabel) => {
matrix[label.toString()][innerLabel.toString()] = 0;
});
});

switch (props.task.trainingInformation.dataType) {
case "image":
(tested.value as Tested["image"]).map(
( {output} ) => matrix[output.label][output.predicted] = matrix[output.label][output.predicted] + 1,
);
break;
//case "text":
// return undefined;
case "tabular":
(tested.value as Tested["tabular"]).results.map(
({ output }) => matrix[output.truth][output.predicted] = matrix[output.truth][output.predicted] + 1,
);
break;
default: {
const _: never = props.task.trainingInformation;
throw new Error("should never happen");
}
}
const mapLabels = Map(labels.map((label, index) => [index, label]));
return {labels : mapLabels, matrix : matrix};
})

const currentAccuracy = computed<string>(() => {
if (tested.value === undefined) return "0";

if (tested.value === undefined) return "0";
let hits: number | undefined;
switch (props.task.trainingInformation.dataType) {
case "image":
Expand Down Expand Up @@ -245,12 +328,12 @@ async function startImageTest(
): Promise<void> {
const validator = new Validator(task, model);
let results: Tested["image"] = List();

try {
generator.value = validator.test(
dataset.map(({ image, label }) => [image, label] as [Image, string]),
);
for await (const [{ filename, image, label }, correct] of dataset.zip(
for await (const [{ filename, image, label }, {result, predicted, truth}] of dataset.zip(
toRaw(generator.value),
)) {
results = results.push({
Expand All @@ -263,8 +346,10 @@ async function startImageTest(
),
},
output: {
truth: label,
correct,
label: label,
correct: result,
predicted: String(predicted),
truth : truth,
},
});

Expand Down Expand Up @@ -295,9 +380,9 @@ async function startTabularTest(
let results: Tested["tabular"]["results"] = List();
try {
generator.value = validator.test(dataset);
for await (const [row, correct] of dataset.zip(toRaw(generator.value))) {
const truth = row[outputColumn];
if (truth === undefined)
for await (const [row, {result, predicted, truth}] of dataset.zip(toRaw(generator.value))) {
const truth_label = row[outputColumn];
if (truth_label === undefined)
throw new Error("row doesn't have expected output column");

results = results.push({
Expand All @@ -308,8 +393,10 @@ async function startTabularTest(
return ret;
}),
output: {
truth,
correct,
truth: truth,
correct: result,
predicted : Number(predicted),
label : truth_label,
},
});

Expand All @@ -330,8 +417,8 @@ async function startTextTest(

try {
generator.value = validator.test(dataset);
for await (const correct of toRaw(generator.value)) {
results = results.push({ output: { correct } });
for await (const {result} of toRaw(generator.value)) {
results = results.push({ output: {correct : result} });
tested.value = results as Tested[D];
}
} finally {
Expand Down Expand Up @@ -368,7 +455,7 @@ function saveCsv(): void {
["Filename", "Truth", "Correct"],
...(tested as Tested["image"]).map(({ input, output }) => [
input.filename,
output.truth,
output.label,
output.correct.toString(),
]),
]);
Expand All @@ -381,7 +468,7 @@ function saveCsv(): void {
.toArray(),
...t.results.map((result) =>
result.input
.concat(result.output.truth)
.concat(result.output.label)
.push(result.output.correct.toString())
.toArray(),
),
Expand Down
Loading