Skip to content

Commit

Permalink
modify the validator type, which causes a whole refractor of the way …
Browse files Browse the repository at this point in the history
…the matrix is handled
  • Loading branch information
tomasoignons committed Dec 17, 2024
1 parent 78965b7 commit 14918ef
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 47 deletions.
22 changes: 16 additions & 6 deletions discojs/src/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,27 @@ export class Validator<D extends DataType> {
/** infer every line of the dataset and check that it is as labelled */
async *test(
dataset: Dataset<DataFormat.Raw[D]>,
): AsyncGenerator<{ result: boolean; predicted: DataFormat.ModelEncoded[D][1]; truth : DataFormat.ModelEncoded[D][1] }, void> {
): AsyncGenerator<{ result: boolean; predicted: DataFormat.Inferred[D]; truth : number }, void> {
const results = (await processing.preprocess(this.task, dataset))
.batch(this.task.trainingInformation.batchSize)
.map(async (batch) =>
(await this.#model.predict(batch.map(([inputs, _]) => inputs)))
.zip(batch.map(([_, outputs]) => outputs))
.map(([inferred, truth]) => ({ result: inferred === truth, predicted: inferred, truth : truth })),
)
.flatten();
.zip(batch.map(([_, outputs]) => outputs))
.map(([inferred, truth]) => ({ result: inferred === truth, predicted: inferred, truth : truth })),
)
.flatten();

const predictions = await processing.postprocess(
this.task,
results.map(({ predicted }) => predicted),
);

const finalResults = results.zip(predictions).map(([result, predicted]) => ({
...result,
predicted,
}));

for await (const e of results) yield e;
for await (const e of finalResults) yield e;
}

/** use the model to predict every line of the dataset */
Expand Down
92 changes: 51 additions & 41 deletions webapp/src/components/testing/TestSteps.vue
Original file line number Diff line number Diff line change
Expand Up @@ -58,33 +58,34 @@
</div>
</div>

<div v-if="confusionMatrix && confusionMatrix.matrix.length > 0" class="p-4 mx-auto lg:w-1/2 h-full bg-white dark:bg-slate-950 rounded-md">

<div v-if="confusionMatrix && confusionMatrix.matrix && Object.keys(confusionMatrix.matrix).length > 0" class="p-4 mx-auto lg:w-1/2 h-full bg-white dark:bg-slate-950 rounded-md">
<h4 class="p-4 text-lg font-semibold text-slate-500 dark:text-slate-300">
Confusion Matrix
</h4>
<table class="min-w-full divide-y divide-slate-600 dark:divide-slate-400 text-center">
<thead>
<tr>
<th class="px-0 py-3 text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider text-center border-r-gray-600 dark:border-r-gray-400 border-r-2 diagonal-header">
<span class="">Label \ Prediction</span>
</th>
<th v-for="(label, index) in confusionMatrix.matrix[0]" :key="'header-' + index" class="text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider">
{{ confusionMatrix.labels.get(index) }}
</th>
</tr>
</thead>
<tbody>
<tr v-for="(row, rowIndex) in confusionMatrix.matrix" :key="'row-' + rowIndex">
<td class="py-2 whitespace-nowrap text-sm font-medium text-gray-800 dark:text-gray-200 border-r-gray-600 dark:border-r-gray-400 border-r-2">
{{ confusionMatrix.labels.get(rowIndex) }}
</td>
<td v-for="(value, colIndex) in row" :key="'col-' + colIndex" class="whitespace-nowrap text-sm dark:text-gray-300 text-gray-700">
{{ value }}
</td>
</tr>
</tbody>
</table>
</div>
Confusion Matrix
</h4>
<table class="min-w-full divide-y divide-slate-600 dark:divide-slate-400 text-center">
<thead>
<tr>
<th class="px-0 py-3 text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider text-center border-r-gray-600 dark:border-r-gray-400 border-r-2 diagonal-header">
<span class="">Label \ Prediction</span>
</th>
<th v-for="(label, index) in Object.keys(confusionMatrix.matrix)" :key="'header-' + index" class="text-xs font-medium text-gray-800 dark:text-gray-200 uppercase tracking-wider">
{{ label }}
</th>
</tr>
</thead>
<tbody>
<tr v-for="(rowLabel, rowIndex) in Object.keys(confusionMatrix.matrix)" :key="'row-' + rowIndex">
<td class="py-2 whitespace-nowrap text-sm font-medium text-gray-800 dark:text-gray-200 border-r-gray-600 dark:border-r-gray-400 border-r-2">
{{ rowLabel }}
</td>
<td v-for="(colLabel, colIndex) in Object.keys(confusionMatrix.matrix[rowLabel])" :key="'col-' + colIndex" class="whitespace-nowrap text-sm dark:text-gray-300 text-gray-700">
{{ confusionMatrix.matrix[rowLabel][colLabel] }}
</td>
</tr>
</tbody>
</table>
</div>

<div v-if="tested !== undefined">
<div class="mx-auto lg:w-1/2 text-center pb-8">
Expand Down Expand Up @@ -157,7 +158,7 @@ import ImageCard from "@/components/containers/ImageCard.vue";
import LabeledDatasetInput from "@/components/dataset_input/LabeledDatasetInput.vue";
import TableLayout from "@/components/containers/TableLayout.vue";
import type { LabeledDataset } from "@/components/dataset_input/types.js";
import { Map, Set } from 'immutable';
import { Map } from 'immutable';
const debug = createDebug("webapp:testing:TestSteps");
const toaster = useToaster();
Expand All @@ -171,7 +172,7 @@ const props = defineProps<{
interface Tested {
image: List<{
input: { filename: string; image: ImageData };
output: { truth: number; correct: boolean; predicted : number, label : string };
output: { truth: number; correct: boolean; predicted : string, label : string };
}>;
tabular: {
labels: {
Expand All @@ -188,7 +189,7 @@ interface Tested {
}
const dataset = ref<LabeledDataset[D]>();
const generator = ref<AsyncGenerator<{result : boolean, predicted : number; truth : number}, void>>();
const generator = ref<AsyncGenerator<{result : boolean, predicted : string | number; truth : number}, void>>();
const tested = ref<Tested[D]>();
const visitedSamples = computed<number>(() => {
Expand All @@ -207,10 +208,10 @@ const visitedSamples = computed<number>(() => {
}
});
const confusionMatrix = computed<{labels : Map<number, string>, matrix : number[][]} | undefined>(() => {
if (tested.value === undefined) return undefined;
// const labels = Set<number>(); // l'idéal serait de tous les avoir en one try, sinon c'est dégueulasse
// const mapLabels = Map<number, string>();
const confusionMatrix = computed<{ labels: Map<number, string>; matrix: { [key: string]: { [key: string]: number } } }>(() => {
if (tested.value === undefined) {
return { labels: Map<number, string>(), matrix: {} };
}
let labels : string[] = [];
switch (props.task.trainingInformation.dataType) {
case "image" :
Expand All @@ -220,21 +221,28 @@ const confusionMatrix = computed<{labels : Map<number, string>, matrix : number[
labels = ["0", "1"]; // binary classification
break;
case "text" :
return undefined;
default: {
const _: never = props.task.trainingInformation;
return { labels: Map<number, string>(), matrix: {} }; default: {
const _: never = props.task.trainingInformation;
throw new Error("should never happen");
}
}
const size = labels.length;
// Initialize the confusion matrix
const matrix = Array.from({ length: size }, () => Array(size).fill(0));
const matrix: { [key: string]: { [key: string]: number } } = {};
// Initialize the confusion matrix
labels.forEach((label) => {
matrix[label.toString()] = {};
labels.forEach((innerLabel) => {
matrix[label.toString()][innerLabel.toString()] = 0;
});
});
switch (props.task.trainingInformation.dataType) {
case "image":
(tested.value as Tested["image"]).map(
( {output} ) => matrix[output.truth][output.predicted] = matrix[output.truth][output.predicted] + 1,
( {output} ) => matrix[output.label][output.predicted] = matrix[output.label][output.predicted] + 1,
);
break;
//case "text":
Expand All @@ -250,6 +258,8 @@ const confusionMatrix = computed<{labels : Map<number, string>, matrix : number[
}
}
const mapLabels = Map(labels.map((label, index) => [index, label]));
console.log(mapLabels)
console.log(matrix)
return {labels : mapLabels, matrix : matrix};
})
Expand Down Expand Up @@ -341,7 +351,7 @@ async function startImageTest(
output: {
label: label,
correct: result,
predicted: predicted,
predicted: String(predicted),
truth : truth,
},
});
Expand Down Expand Up @@ -388,7 +398,7 @@ async function startTabularTest(
output: {
truth: truth,
correct: result,
predicted : predicted,
predicted : Number(predicted),
label : truth_label,
},
});
Expand Down

0 comments on commit 14918ef

Please sign in to comment.