diff --git a/src/classifier.js b/src/classifier.js index 1ebd7f53..0ea120f3 100644 --- a/src/classifier.js +++ b/src/classifier.js @@ -38,9 +38,9 @@ async function main() { for (const path of paths) { try { - let results = (await model.inference(path, { - topK: 5 - })).result + let results = await model.inference(path, { + topK: 7 + }) const labels = [] results = results @@ -61,7 +61,7 @@ async function main() { return false } const threshold = result.rule.threshold - if (result.precision < threshold) { + if (result.probability < threshold) { return false } return true @@ -92,13 +92,13 @@ async function main() { cat_probabilities[category] = 0 } if (!(category in cat_thresholds)) { - cat_thresholds[category] = 0 + cat_thresholds[category] = 1 } if (!(category in cat_count)) { cat_count[category] = 0 } - cat_probabilities[category] += result.precision - cat_thresholds[category] = cat_thresholds[category] < result.rule.threshold ? result.rule.threshold : cat_thresholds[category] + cat_probabilities[category] += result.probability + cat_thresholds[category] += result.rule.threshold**2 cat_count[category]++ }) } @@ -108,7 +108,7 @@ async function main() { if (cat_count[category] <= 1) { return false } - return probability >= cat_thresholds[category] ** 2 + 0.15 + return probability >= (cat_thresholds[category]/cat_count[category])**(1/2) }) .forEach(([category]) => { labels.push(category) diff --git a/src/rules.yml b/src/rules.yml index da47657b..f29dbfcc 100644 --- a/src/rules.yml +++ b/src/rules.yml @@ -151,7 +151,7 @@ rapeseed: fashion: label: portrait - threshold: 0.2 + threshold: 0.27 categories: - portrait @@ -2037,7 +2037,7 @@ gibbon: monkey: label: monkey - threshold: 0.55 + threshold: 0.7 priority: 2 categories: - animal @@ -2207,7 +2207,7 @@ puffer: instrument: label: instrument context: tool - threshold: 1 + threshold: 0.8 priority: -2 categories: - music @@ -2603,7 +2603,7 @@ stone wall: desktop computer: label: computer - threshold: 0.3 + threshold: 0.4 categories: - office @@ -2612,7 +2612,7 @@ dial telephone: threshold: 0.25 clock: - threshold: 0.6 + threshold: 0.5 label: clock analog clock: @@ -2622,7 +2622,8 @@ wall clock: see: clock digital clock: - see: clock + threshold: 1 + label: clock digital watch: see: clock @@ -2817,7 +2818,7 @@ harp: tool: label: tool - threshold: 0.5 + threshold: 0.75 priority: -1 hatchet: @@ -2836,7 +2837,7 @@ honeycomb: jack-o'-lantern: label: pumpkin - threshold: 0.25 + threshold: 0.55 categories: - vegetables @@ -3025,7 +3026,7 @@ tricycle: motor scooter: label: scooter - threshold: 0.35 + threshold: 0.4 categories: - vehicle @@ -3191,7 +3192,7 @@ sunscreen: screen: label: screen - threshold: 0.3 + threshold: 0.6 perfume: label: bottle @@ -3411,7 +3412,7 @@ sliding door: snorkel: label: diving - threshold: 0.4 + threshold: 0.7 categories: - water @@ -3495,7 +3496,7 @@ strainer: threshold: 0.71 streetcar: - threshold: 0.28 + threshold: 0.35 categories: - train - vehicle @@ -3879,7 +3880,7 @@ custard apple: pomegranate: label: fruit - threshold: 0.23 + threshold: 0.3 categories: - food @@ -3949,7 +3950,7 @@ alp: cliff: label: landscape - threshold: 0.15 + threshold: 0.10 promontory: see: cliff