-
Notifications
You must be signed in to change notification settings - Fork 0
/
DigitClassifier.fsx
74 lines (60 loc) · 2.23 KB
/
DigitClassifier.fsx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#load "packages/FSharp.Charting.0.82/FSharp.Charting.fsx"
open System
open System.IO
open FSharp.Charting
(******* 1. GETTING SOME DATA *******)
/// Lines of `trainingsample.csv` (located next to this script), with the
/// first line of the file dropped (presumably the CSV header row — confirm
/// against the data file).
let dataLines =
    // Path.Combine inserts the platform's own directory separator; the
    // previous hard-coded '\' only resolved correctly on Windows.
    let samplePath = Path.Combine(__SOURCE_DIRECTORY__, "trainingsample.csv")
    File.ReadAllLines(samplePath).[1..]
(******* 2. EXTRACTING COLUMNS *******)
let dataNumbers =
dataLines
|> Array.map (fun line -> line.Split(','))
|> Array.map (Array.map (int))
(******* 3. CONVERTING ARRAYS TO RECORDS *******)
/// One sample: a label plus the remaining integer values of its row.
type DigitRecord =
    {
        /// First value of the row.
        Label: int
        /// All values of the row after the first.
        Pixels: int[]
    }
/// Every numeric row turned into a DigitRecord: the first value becomes the
/// label, the rest of the row becomes the pixel array.
let dataRecords =
    [| for row in dataNumbers -> { Label = row.[0]; Pixels = row.[1..] } |]
(******* 4. TRAINING vs VALIDATION DATA *******)
// Fixed three-way split of the records by position:
//   records 0..3999    -> training
//   records 4000..4499 -> cross validation
//   records 4500..end  -> test
let trainingSet, crossValidationSet, testSet =
    dataRecords.[0..3999], dataRecords.[4000..4499], dataRecords.[4500..]
(******* 5. COMPUTING DISTANCE *******)
/// Squared Euclidean distance between an unlabelled pixel array and the
/// pixels of a known record: the sum over all positions of the squared
/// difference of the two values, as an int64.
///
/// Raises ArgumentException when the two pixel arrays differ in length.
let distanceTo (unknownDigit:int[]) (knownDigit:DigitRecord) =
    Array.fold2
        (fun total unknown known ->
            // Widen to int64 *before* subtracting and squaring. Squaring in
            // 32-bit int and converting afterwards would already have wrapped
            // around for large differences, making the conversion pointless.
            let difference = int64 unknown - int64 known
            total + difference * difference)
        0L
        unknownDigit
        knownDigit.Pixels
(******* 6. THE CLASSIFIER FUNCTION *******)
/// Predicts a label for `unknownDigit` by majority vote among the `k`
/// training records closest to it (closeness measured by `distanceTo`).
///
/// k            - number of neighbours that vote; must be at least 1. When k
///                is larger than the training set, every training record votes.
/// unknownDigit - pixel values of the sample to classify.
///
/// Returns the label that occurs most often among the selected neighbours.
/// Raises ArgumentException when k < 1 or when the training set is empty.
let classifyByNearest k (unknownDigit:int[]) =
    // Fail with a clear message instead of the opaque "sequence was empty"
    // error that Seq.maxBy would otherwise produce further down.
    if k < 1 then invalidArg "k" "k must be at least 1"
    trainingSet
    |> Array.sortBy (distanceTo unknownDigit)
    // Seq.truncate returns at most k items; Seq.take would throw when the
    // training set holds fewer than k records.
    |> Seq.truncate k
    |> Seq.countBy (fun neighbour -> neighbour.Label)
    |> Seq.maxBy snd
    |> fst
(******* 7. SEE THE CLASSIFIER IN ACTION *******)
// Spot check: classify the first five test records with k = 3 and print
// each actual label next to the predicted one.
for sample in testSet.[0..4] do
    let predicted = classifyByNearest 3 sample.Pixels
    printfn "Actual: %d, Predicted: %d" sample.Label predicted
(******* 8. EVALUATING THE MODEL AGAINST VALIDATION DATA *******)
/// Fraction (0.0 to 1.0) of records in `dataSet` whose label is predicted
/// correctly by `classifyByNearest k`.
let calculateAccuracyWithNearest k dataSet =
    // 1.0 for a correct prediction, 0.0 otherwise; the mean of these scores
    // is the accuracy.
    let score digit =
        let predicted = classifyByNearest k digit.Pixels
        if predicted = digit.Label then 1.0 else 0.0
    Array.averageBy score dataSet
(******* 9. GET AN IDEA ABOUT A GOOD RANGE OF K *******)
/// (k, accuracy) pairs on the cross-validation set for a few widely spaced
/// values of k.
let predictionAccuracy =
    [ for k in [1;3;9;27] -> (k, calculateAccuracyWithNearest k crossValidationSet) ]
// Plot accuracy against k.
Chart.Line predictionAccuracy
(******* 10. FIND THE BEST K *******)
/// The k in 1..20 that gives the highest accuracy on the cross-validation set.
let bestK =
    let accuracyFor k = calculateAccuracyWithNearest k crossValidationSet
    List.maxBy accuracyFor [1..20]
(******* 11. THE FINAL RESULT! *******)
// Accuracy of the classifier with the selected k on the held-out test set.
calculateAccuracyWithNearest bestK testSet