Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Emmanuel confusion matrix #833

Open
wants to merge 19 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
5191140
Return the predicted label in the tests
tomasoignons Nov 21, 2024
8bb5b8c
Added the confusion matrix generation
tomasoignons Nov 21, 2024
ee6d844
Added dark mode support and labels for the confusion matrix
tomasoignons Nov 24, 2024
696c3f8
remove logs and unused variables
tomasoignons Nov 24, 2024
2debf68
merge commit
tomasoignons Dec 16, 2024
4577a0c
setting the map and set to immutable, and adding a pseudo confusion m…
tomasoignons Dec 16, 2024
2854bd8
Simplify code
tomasoignons Dec 16, 2024
6762c1a
Merge branch 'emmanuel-confusion-matrix' of github.com:epfml/disco in…
tomasoignons Dec 16, 2024
7839dfc
Simplify code by removing a div
tomasoignons Dec 16, 2024
8b0bc1f
reverse the truth and label labels, really confusing
tomasoignons Dec 16, 2024
0fc88ee
removed the ugly way to compute labels, because we have them all in t…
tomasoignons Dec 16, 2024
dae4ba6
added matrix for tabular binary classification
tomasoignons Dec 16, 2024
30d84a4
small mistake between columns and row
tomasoignons Dec 16, 2024
78965b7
remove some spaces on the rendering of the matrix
tomasoignons Dec 16, 2024
14918ef
modify the validator type, which causes a whole refactor of the way …
tomasoignons Dec 17, 2024
7b90480
remove console.log and unused variables
tomasoignons Dec 17, 2024
817ecb5
remove a span element that doesn't do anything
tomasoignons Dec 17, 2024
53ada75
modification of the validator to make it more clear
tomasoignons Dec 17, 2024
d63fc11
modification of the timeout of the test of the validator 'can read an…
tomasoignons Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions discojs/src/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ export class Validator<D extends DataType> {
/** infer every line of the dataset and check that it is as labelled */
async *test(
dataset: Dataset<DataFormat.Raw[D]>,
): AsyncGenerator<boolean, void> {
): AsyncGenerator<{ result: boolean; predicted: DataFormat.ModelEncoded[D][1]; truth : DataFormat.ModelEncoded[D][1] }, void> {
tomasoignons marked this conversation as resolved.
Show resolved Hide resolved
const results = (await processing.preprocess(this.task, dataset))
.batch(this.task.trainingInformation.batchSize)
.map(async (batch) =>
(await this.#model.predict(batch.map(([inputs, _]) => inputs)))
.zip(batch.map(([_, outputs]) => outputs))
.map(([inferred, truth]) => inferred === truth),
)
.map(([inferred, truth]) => ({ result: inferred === truth, predicted: inferred, truth : truth })),
)
.unbatch();

for await (const e of results) yield e;
Expand Down
111 changes: 97 additions & 14 deletions webapp/src/components/testing/TestSteps.vue
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,34 @@
</div>
</div>

<div v-if="confusionMatrix && confusionMatrix.matrix.length > 0" class="p-4 mx-auto lg:w-1/2 h-full bg-white dark:bg-slate-950 rounded-md">
<h4 class="p-4 text-lg font-semibold text-slate-500 dark:text-slate-300">
Confusion Matrix
</h4>
<table class="min-w-full divide-y divide-slate-600 dark:divide-slate-400 text-center">
<thead>
<tr>
<th class="pl-6 py-3 text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider text-center border-r-gray-600 dark:border-r-gray-400 border-r-2 diagonal-header">
<span class="">Label \ Prediction</span>
tomasoignons marked this conversation as resolved.
Show resolved Hide resolved
</th>
<th v-for="(label, index) in confusionMatrix.matrix[0]" :key="'header-' + index" class="px-6 py-3 text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider">
tomasoignons marked this conversation as resolved.
Show resolved Hide resolved
{{ confusionMatrix.labels.get(index) }}
</th>
</tr>
</thead>
<tbody>
<tr v-for="(row, rowIndex) in confusionMatrix.matrix" :key="'row-' + rowIndex">
<td class="px-6 py-4 whitespace-nowrap text-sm font-medium text-gray-800 dark:text-gray-200 border-r-gray-600 dark:border-r-gray-400 border-r-2">
{{ confusionMatrix.labels.get(rowIndex) }}
</td>
<td v-for="(value, colIndex) in row" :key="'col-' + colIndex" class="px-6 py-4 whitespace-nowrap text-sm dark:text-gray-300 text-gray-700">
{{ value }}
</td>
</tr>
</tbody>
</table>
</div>

<div v-if="tested !== undefined">
<div class="mx-auto lg:w-1/2 text-center pb-8">
<CustomButton @click="saveCsv()"> download as csv </CustomButton>
Expand Down Expand Up @@ -142,24 +170,24 @@ const props = defineProps<{
interface Tested {
image: List<{
input: { filename: string; image: ImageData };
output: { truth: string; correct: boolean };
output: { truth: string; correct: boolean; predicted : number, label : number };
}>;
tabular: {
labels: {
input: List<string>;
output: { truth: string; correct: string };
output: { truth: string; correct: string, label : string };
};
results: List<{
input: List<string>;
output: { truth: string; correct: boolean };
output: { truth: string; correct: boolean; predicted : number, label : number };
tomasoignons marked this conversation as resolved.
Show resolved Hide resolved
}>;
};
// TODO what to show?
text: List<{ output: { correct: boolean } }>;
}

const dataset = ref<LabeledDataset[D]>();
const generator = ref<AsyncGenerator<boolean, void>>();
const generator = ref<AsyncGenerator<{result : boolean, predicted : number; truth : number}, void>>();
const tested = ref<Tested[D]>();

const visitedSamples = computed<number>(() => {
Expand All @@ -177,9 +205,60 @@ const visitedSamples = computed<number>(() => {
}
}
});

const confusionMatrix = computed<{labels : Map<number, string>, matrix : number[][]} | undefined>(() => {
if (tested.value === undefined) return undefined;
const labels = new Set<number>();
const mapLabels = new Map<number, string>();
tomasoignons marked this conversation as resolved.
Show resolved Hide resolved

// get all the labels
switch (props.task.trainingInformation.dataType) {
case "image":
(tested.value as Tested["image"]).forEach(({ output }) => {
labels.add(output.label);
labels.add(output.predicted);
mapLabels.set(output.label, output.truth);
});
break;
case "text":
return undefined;
case "tabular":
(tested.value as Tested["tabular"]).results.forEach(({ output }) => {
labels.add(output.label);
labels.add(output.predicted);
mapLabels.set(output.label, output.truth);
});
break;
tomasoignons marked this conversation as resolved.
Show resolved Hide resolved
default: {
const _: never = props.task.trainingInformation;
throw new Error("should never happen");
}
}
const size = Math.max(labels.size, Math.max(...Array.from(labels)));
tomasoignons marked this conversation as resolved.
Show resolved Hide resolved
// Initialize the confusion matrix
const matrix = Array.from({ length: size }, () => Array(size).fill(0));

switch (props.task.trainingInformation.dataType) {
case "image":
(tested.value as Tested["image"]).map(
( {output} ) => matrix[output.predicted][output.label] = matrix[output.predicted][output.label] + 1,
);
break;
//case "text":
// return undefined;
case "tabular":
return undefined;
tomasoignons marked this conversation as resolved.
Show resolved Hide resolved
default: {
const _: never = props.task.trainingInformation;
throw new Error("should never happen");
}
}
return {labels : mapLabels, matrix : matrix};
})

const currentAccuracy = computed<string>(() => {
if (tested.value === undefined) return "0";

if (tested.value === undefined) return "0";
let hits: number | undefined;
switch (props.task.trainingInformation.dataType) {
case "image":
Expand Down Expand Up @@ -250,7 +329,7 @@ async function startImageTest(
generator.value = validator.test(
dataset.map(({ image, label }) => [image, label] as [Image, string]),
);
for await (const [{ filename, image, label }, correct] of dataset.zip(
for await (const [{ filename, image, label }, {result, predicted, truth}] of dataset.zip(
toRaw(generator.value),
)) {
results = results.push({
Expand All @@ -264,7 +343,9 @@ async function startImageTest(
},
output: {
truth: label,
correct,
correct: result,
predicted: predicted,
label : truth,
tomasoignons marked this conversation as resolved.
Show resolved Hide resolved
},
});

Expand Down Expand Up @@ -295,9 +376,9 @@ async function startTabularTest(
let results: Tested["tabular"]["results"] = List();
try {
generator.value = validator.test(dataset);
for await (const [row, correct] of dataset.zip(toRaw(generator.value))) {
const truth = row[outputColumn];
if (truth === undefined)
for await (const [row, {result, predicted, truth}] of dataset.zip(toRaw(generator.value))) {
const truth_label = row[outputColumn];
if (truth_label === undefined)
throw new Error("row doesn't have expected output column");

results = results.push({
Expand All @@ -308,8 +389,10 @@ async function startTabularTest(
return ret;
}),
output: {
truth,
correct,
truth: truth_label,
correct: result,
predicted : predicted,
label : truth,
},
});

Expand All @@ -330,8 +413,8 @@ async function startTextTest(

try {
generator.value = validator.test(dataset);
for await (const correct of toRaw(generator.value)) {
results = results.push({ output: { correct } });
for await (const {result} of toRaw(generator.value)) {
results = results.push({ output: {correct : result} });
tested.value = results as Tested[D];
}
} finally {
Expand Down