## Digit recognition in F# with k-nearest neighbours

In this post, I’ll step through using a simple Machine Learning algorithm called k-nearest neighbours (KNN) to perform handwritten digit recognition. This is another of the Hello Worlds of ML – it’s how I got introduced to it (via Mathias Brandewinder’s Digit Recogniser Dojo), and also one I’ve run as a dojo a few times myself. It’s based on the Kaggle learning competition of the same name.

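KNN is a lazy classification algorithm, which is used to determine what an unknown example of something is based on its similarity to known examples. "Lazy" means there is no real training step: the known examples are simply stored, and all the work happens at classification time. Specifically, it finds the k most similar (nearest) known examples, and classifies the unknown item according to the most common label among them.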

Here’s a visualisation, in which we’re trying to work out whether the unknown green thing is a blue square or a red triangle:

Based on the three nearest examples, we would classify the unknown item as a red triangle.

(Interestingly, if we’d chosen a k of 5, we would say it’s a blue square – we’ll discuss how to approach that later).
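To see how the choice of k can flip the answer, here’s a minimal sketch of the majority vote on some made-up 2D points. The coordinates and the `Shape` type are hypothetical, picked so that k = 3 and k = 5 disagree in the same way as the example above. It ranks points by squared distance, which is explained next.

```fsharp
// Hypothetical labels and points, chosen so the vote changes with k.
type Shape = Square | Triangle

let known =
    [| (Triangle, (1.0, 0.0))
       (Triangle, (0.0, 1.2))
       (Square,   (-1.1, 0.0))
       (Square,   (0.0, -1.5))
       (Square,   (1.6, 1.6)) |]

// Squared Euclidean distance between two 2D points.
let squaredDistance (x1: float, y1: float) (x2: float, y2: float) =
    let dx = x1 - x2
    let dy = y1 - y2
    dx * dx + dy * dy

// Take the k nearest known points and return the most common label among them.
let classify k point =
    known
    |> Array.sortBy (fun (_, p) -> squaredDistance point p)
    |> Array.take k
    |> Array.countBy fst
    |> Array.maxBy snd
    |> fst

printfn "k=3: %A" (classify 3 (0.0, 0.0)) // Triangle (2 triangles vs 1 square)
printfn "k=5: %A" (classify 5 (0.0, 0.0)) // Square (3 squares vs 2 triangles)
```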

The dataset we’ll be running this on is (a subset of) the MNIST handwritten digits, which consists of several thousand 28×28 pixel greyscale images that look a bit like this:

So how do you determine the ‘distance’ between two such images? A common way is to use Euclidean distance, which in this situation involves comparing each pair of corresponding pixels and summing the differences between them; or, to be more precise, summing the squares of the differences.

(The actual equation takes the square root of the sum of the squares: for two images $p$ and $q$ of 784 pixels each, $d(p, q) = \sqrt{\sum_{i=1}^{784} (p_i - q_i)^2}$. Because the square root never changes the order of two non-negative numbers, dropping it doesn’t change which images are nearest, so it makes no difference here.) So if two images were identical, the distance would be 0. If two of the pixels were different, by 20 and 50 respectively, the distance would be $20^2 + 50^2 = 400 + 2500 = 2900$ (still very similar).
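A quick way to convince yourself that skipping the square root is safe is to rank some candidates both ways. This is a minimal sketch with hypothetical two-pixel "images"; `euclidean` and `squaredEuclidean` are illustrative names, not functions from the post.

```fsharp
// Full Euclidean distance: square root of the summed squared pixel differences.
let euclidean (a: int[]) (b: int[]) =
    Array.map2 (fun x y -> float ((x - y) * (x - y))) a b
    |> Array.sum
    |> sqrt

// Squared distance: the same sum, without the final square root.
let squaredEuclidean (a: int[]) (b: int[]) =
    Array.map2 (fun x y -> int64 ((x - y) * (x - y))) a b
    |> Array.sum

let unknown = [| 100; 100 |]
let candidates = [| [| 120; 150 |]; [| 100; 101 |]; [| 0; 255 |] |]

// sqrt is monotonically increasing, so both measures rank the candidates identically.
let byEuclidean = candidates |> Array.sortBy (euclidean unknown)
let bySquared = candidates |> Array.sortBy (squaredEuclidean unknown)

printfn "%b" (byEuclidean = bySquared)                  // true
printfn "%d" (squaredEuclidean unknown [| 120; 150 |])  // 2900
```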

OK, lesson over, let’s get implementing! I’ll be doing this in F#, but it’s pretty similar in C# if you use a lot of LINQ.

I’ve put the source code on GitHub. (You can find a basic version of it in C#, as well as several other languages, in my Digit Recogniser Dojo repository.) I’m using arrays for most of this: although they aren’t the most idiomatic choice in F#, they give better performance, which can make a real difference when working with big datasets.

I’m going to be using 5000 digits, which is enough to get pretty good results without taking too long. They come formatted as a CSV, with each record having a label (i.e. which digit it represents) followed by each of the 784 (28×28) pixels, represented by a number from 0 to 255.

First of all, let’s read the records, and get rid of the header:

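```fsharp
open System.IO

// Read every line of the CSV and drop the first one (the header row)
let dataLines =
    File.ReadAllLines(Path.Combine(__SOURCE_DIRECTORY__, "trainingsample.csv")).[1..]
```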

We can check our progress using F# Interactive, and see that we now have an array of strings containing a record each:

```
val dataLines : string [] =
  [|"1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"+[1669 chars];
    "0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"+[1925 chars];
```

We can then split the lines up and parse the numbers:

```fsharp
// Split each line on commas, then parse every field as an int
let dataNumbers =
    dataLines
    |> Array.map (fun line -> line.Split(','))
    |> Array.map (Array.map int)
```

Now we have an array of arrays of integers:

val dataNumbers : int [] [] = [|[|1; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; ...|]; [|0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;

Working with arrays of arrays doesn’t make things too clear, so let’s create a type to store the records in:

```fsharp
type DigitRecord = { Label:int; Pixels:int[] }

// The first number is the label; the remaining 784 are the pixels
let dataRecords =
    dataNumbers
    |> Array.map (fun record -> { Label = record.[0]; Pixels = record.[1..] })
```

So now we have an array of DigitRecords:

val dataRecords : DigitRecord [] = [|{Label = 1; Pixels = [|0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; ...|];};
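Before going further it can be worth checking that every record parsed the way we expect, since a truncated line or a stray header would otherwise quietly produce bad distances. This is a minimal sketch; `isValidRecord` is a hypothetical helper, and it assumes labels 0 to 9 and 784 pixels in the range 0 to 255, as described above.

```fsharp
type DigitRecord = { Label:int; Pixels:int[] }

// Hypothetical sanity check: label 0-9, exactly 28x28 pixels, each pixel 0-255.
let isValidRecord (r: DigitRecord) =
    r.Label >= 0 && r.Label <= 9
    && r.Pixels.Length = 28 * 28
    && r.Pixels |> Array.forall (fun p -> p >= 0 && p <= 255)

let good = { Label = 7; Pixels = Array.zeroCreate 784 }  // 784 zero pixels
let bad = { Label = 7; Pixels = Array.create 783 0 }     // one pixel short

printfn "%b %b" (isValidRecord good) (isValidRecord bad) // true false
```

On the real data, something like `dataRecords |> Array.forall isValidRecord` should return `true`.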

Thinking ahead a bit, we’re going to need some way of testing our algorithm to see how good it is (and also, as described shortly, to work out what value to use for k). The typical approach is to split the known dataset into three: one part to train the algorithm, one to choose which options to use, and another to test the accuracy of the final system. With our 5000 records, that gives 4000 for training, 500 for cross-validation and 500 for testing:

```fsharp
let trainingSet = dataRecords.[..3999]            // 4000 records
let crossValidationSet = dataRecords.[4000..4499] // 500 records
let testSet = dataRecords.[4500..]                // the remaining 500
```
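Because these are just consecutive slices, the split assumes the file isn’t ordered by label. One way to check is to count how many examples of each digit land in each set. This is a sketch; `labelCounts` is a hypothetical helper, and the `sample` data is made up for illustration.

```fsharp
type DigitRecord = { Label:int; Pixels:int[] }

// Count the examples of each digit in a set, sorted by label.
let labelCounts (records: DigitRecord[]) =
    records
    |> Array.countBy (fun r -> r.Label)
    |> Array.sortBy fst

// Made-up stand-in data; on the real data you would call labelCounts trainingSet, etc.
let sample =
    [| 1; 0; 1; 4; 0; 1 |] |> Array.map (fun l -> { Label = l; Pixels = [||] })

printfn "%A" (labelCounts sample) // [|(0, 2); (1, 3); (4, 1)|]
```

If one set were missing a digit entirely, or heavily skewed towards one, shuffling before slicing would be worth considering.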

We’re nearly at the fun part, so we’ll soon need a way of measuring the distance between two digits, as described earlier. F#’s `Array.map2` function comes in very handy for this, applying a function to two same-sized arrays at the same time:

```fsharp
let distanceTo (unknownDigit:int[]) (knownDigit:DigitRecord) =
    Array.map2 (fun unknown known ->
        let difference = unknown - known
        int64 (difference * difference)) unknownDigit knownDigit.Pixels
    |> Array.sum
```

For the unknown parameter we’ll take the raw pixels, because if we ever use this on real data we won’t have a label for it (so it won’t be in a DigitRecord).

Let’s give it a test run:

```
> {Label=1;Pixels=[|120;150|]} |> distanceTo [|100;100|];;
val it : int64 = 2900L
```

Looks about right!

Now we get to the real meat of it: classifying an unknown digit based on the k nearest known examples. In this function, we compare the digit against the whole training set, take the k closest, and count which label occurs most often:

```fsharp
let classifyByNearest k (unknownDigit:int[]) =
    trainingSet
    |> Array.sortBy (distanceTo unknownDigit)
    |> Seq.take k
    |> Seq.countBy (fun digit -> digit.Label)
    |> Seq.maxBy (fun (label, count) -> count)
    |> fun (label, count) -> label
```

F# is particularly good at this kind of code, although you can write something only slightly more complicated in C# with a good helping of LINQ (which is basically a functional library).
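Because `classifyByNearest` reads `trainingSet` directly, it’s hard to try out on small inputs. Here’s a sketch of the same pipeline with the known examples passed in as a parameter; `classifyWith` and the two-pixel "digits" are hypothetical. It also swaps `Seq.take` for `Array.truncate`, which (unlike `Seq.take`) doesn’t throw if there are fewer than k examples.

```fsharp
type DigitRecord = { Label:int; Pixels:int[] }

let distanceTo (unknownDigit: int[]) (knownDigit: DigitRecord) =
    Array.map2 (fun a b -> let d = a - b in int64 (d * d)) unknownDigit knownDigit.Pixels
    |> Array.sum

// Same steps as classifyByNearest, but the known examples are a parameter.
let classifyWith (known: DigitRecord[]) k (unknownDigit: int[]) =
    known
    |> Array.sortBy (distanceTo unknownDigit)   // nearest first
    |> Array.truncate k                          // at most k of them
    |> Array.countBy (fun digit -> digit.Label)  // (label, votes) pairs
    |> Array.maxBy snd                           // the label with the most votes
    |> fst

// Made-up two-pixel "digits" for illustration.
let tiny =
    [| { Label = 0; Pixels = [| 0; 0 |] }
       { Label = 0; Pixels = [| 10; 0 |] }
       { Label = 1; Pixels = [| 200; 200 |] } |]

printfn "%d" (classifyWith tiny 1 [| 190; 210 |]) // 1: the single nearest example is a 1
printfn "%d" (classifyWith tiny 3 [| 190; 210 |]) // 0: the two 0s outvote it
```

With the real data, `classifyWith trainingSet k` should behave like `classifyByNearest k`.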

And with that, we have everything we need to start making predictions! Let’s give it a blast. We need to choose a value for k, so let’s just use 1 for now:

```fsharp
testSet.[..4]
|> Array.iter (fun digit ->
    printfn "Actual: %d, Predicted: %d" digit.Label (digit.Pixels |> classifyByNearest 1))
```

```
Actual: 5, Predicted: 6
Actual: 3, Predicted: 3
Actual: 9, Predicted: 9
Actual: 8, Predicted: 8
Actual: 7, Predicted: 7
```

Four out of five, not too shabby! Let’s write a function to find out what the accuracy is for the whole validation set. We’re taking the dataset to calculate the accuracy on as a parameter because we’ll use it in two different ways, which you’ll see shortly.

```fsharp
let calculateAccuracyWithNearest k dataSet =
    dataSet
    |> Array.averageBy (fun digit ->
        if digit.Pixels |> classifyByNearest k = digit.Label then 1.0 else 0.0)
```

So, let’s give it a try! We still haven’t yet worked out what value to use for k, so let’s stick with 1 for now.

```
> testSet |> calculateAccuracyWithNearest 1;;
val it : float = 0.93
```

93%, not bad for 30-odd lines of code!

But what about k, how do we decide what value to use? This is where the cross-validation set comes in. The idea is that you use the cross-validation set to test the various options you have (in this case, different values of k) and find the best one(s). You can then use the test set to get a final figure for the accuracy of your algorithm. The reason we have separate validation and test sets, instead of just using the same test set for both, is to help get a final figure which is more representative of what our algorithm will get when used on real data.

If we just took the accuracy as measured by the validation set when finding the best options, we’d have an artificially high figure, as we’ve specifically chosen the options to maximise that number. The chances are, when we run it on real data, the accuracy will be lower.

To begin with, we can try a few values for k and plot the result (with FSharp.Charting), to see if we’re in the right ballpark:

```fsharp
open FSharp.Charting // assumes the FSharp.Charting package is referenced

let predictionAccuracy =
    [1; 3; 9; 27]
    |> List.map (fun k -> (k, crossValidationSet |> calculateAccuracyWithNearest k))

Chart.Line(predictionAccuracy)
```

```
val predictionAccuracy : (int * float) list =
  [(1, 0.93); (3, 0.936); (9, 0.94); (27, 0.91)]
```

As you can hopefully see, the accuracy seems to peak somewhere around 10 and then starts to drop off. So let’s try every value from 1 to 20 to find the best one (this may take a while, so time to get a coffee!):

```fsharp
let bestK =
    [1..20]
    |> List.maxBy (fun k -> crossValidationSet |> calculateAccuracyWithNearest k)
```
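Most of that wait comes from re-sorting the whole training set for every validation digit, once per value of k. Since the k nearest neighbours are just the first k of the same sorted list, one option is to sort once per digit and reuse the result for every k. This is a sketch under that idea; `vote`, `accuracyForEachK` and the tiny data set are hypothetical.

```fsharp
type DigitRecord = { Label:int; Pixels:int[] }

let distanceTo (unknownDigit: int[]) (knownDigit: DigitRecord) =
    Array.map2 (fun a b -> let d = a - b in int64 (d * d)) unknownDigit knownDigit.Pixels
    |> Array.sum

// Majority label among the first k entries of a nearest-first array of labels.
let vote k (nearestLabels: int[]) =
    nearestLabels |> Array.truncate k |> Array.countBy id |> Array.maxBy snd |> fst

// Accuracy for every k from 1 to maxK, sorting the training set once per validation digit.
let accuracyForEachK (training: DigitRecord[]) (validation: DigitRecord[]) maxK =
    // For each validation digit: its true label and the labels of its maxK nearest neighbours.
    let nearest =
        validation
        |> Array.map (fun digit ->
            let labels =
                training
                |> Array.sortBy (distanceTo digit.Pixels)
                |> Array.truncate maxK
                |> Array.map (fun r -> r.Label)
            digit.Label, labels)
    // Reuse the same neighbour lists for every k.
    [ for k in 1 .. maxK do
        let accuracy =
            nearest
            |> Array.averageBy (fun (actual, labels) -> if vote k labels = actual then 1.0 else 0.0)
        yield k, accuracy ]

// Made-up example: three known digits and two to validate.
let training =
    [| { Label = 0; Pixels = [| 0; 0 |] }
       { Label = 0; Pixels = [| 10; 0 |] }
       { Label = 1; Pixels = [| 200; 200 |] } |]
let validation =
    [| { Label = 1; Pixels = [| 190; 210 |] }
       { Label = 0; Pixels = [| 5; 0 |] } |]

let results = accuracyForEachK training validation 3
printfn "%A" results
```

On the real data, `accuracyForEachK trainingSet crossValidationSet 20 |> List.maxBy snd |> fst` should pick the same `bestK` while doing the expensive sorting 20 times less often.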

Turns out the best value for k, at least according to the validation set, is 6. On the validation set, that gets us a heady 94% accuracy:

```
> crossValidationSet |> calculateAccuracyWithNearest bestK;;
val it : float = 0.94
```

As previously described, that’s likely to be a little optimistic, so let’s get a final measure using the test set:

```
> testSet |> calculateAccuracyWithNearest bestK;;
val it : float = 0.926
```

Final result: 92.6%!

Not a bad result for a simple algorithm. In fact, in this case we actually get a better result on the test set using a k of 1 (93%). With so little in it, we might decide to just use a ‘1-nearest neighbour’ algorithm, which would be simpler still.

Like many Machine Learning systems, this also gets better the more data you use: I submitted code based on this to the Kaggle contest, which has a dataset of 50,000 digits, and got an accuracy of around 97%. More data would also help reduce the disparity between the validation and test sets, and make the options we chose using the validation set more likely to be optimal for the test set as well.

KNN can be used in several situations; you just need some way of measuring the similarity between two things. For example, you could use Levenshtein distance to compare strings and find the best match for somebody’s name, or physical distance to find the nearest post office to a given home.
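As an illustration of swapping in a different distance, here’s a sketch of the name-matching idea: a standard dynamic-programming Levenshtein distance, plus a 1-nearest-neighbour lookup over a list of known names. The names `levenshtein` and `closestName` and the example names are hypothetical.

```fsharp
// Levenshtein distance: insertions, deletions and substitutions each cost 1.
let levenshtein (s: string) (t: string) =
    // d.[i, j] holds the distance between the first i chars of s and the first j chars of t.
    let d = Array2D.create (s.Length + 1) (t.Length + 1) 0
    for i in 0 .. s.Length do d.[i, 0] <- i
    for j in 0 .. t.Length do d.[0, j] <- j
    for i in 1 .. s.Length do
        for j in 1 .. t.Length do
            let cost = if s.[i - 1] = t.[j - 1] then 0 else 1
            d.[i, j] <- min (min (d.[i - 1, j] + 1) (d.[i, j - 1] + 1)) (d.[i - 1, j - 1] + cost)
    d.[s.Length, t.Length]

// 1-nearest neighbour over strings: the known name with the smallest edit distance.
let closestName (knownNames: string list) (query: string) =
    knownNames |> List.minBy (levenshtein query)

printfn "%d" (levenshtein "kitten" "sitting")                 // 3
printfn "%s" (closestName [ "Mary"; "John"; "Joan" ] "Marie") // Mary
```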

An interesting alternative is the weighted nearest neighbour algorithm, which may consider all neighbours (not just the k nearest) but weights each one’s influence by its distance: close ones count for a lot when taking the vote, distant ones hardly at all.
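Here’s a minimal sketch of such a weighted vote, assuming each neighbour contributes 1 / distance to its label’s score. That is one common choice of weight (others, such as 1 / distance², are also used). The `weightedVote` and `majorityVote` helpers, the small `epsilon`, and the neighbour list are hypothetical.

```fsharp
// Weighted vote: each (label, distance) neighbour adds 1 / (distance + epsilon) to its label.
// epsilon is a hypothetical small constant that avoids dividing by zero on an exact match.
let weightedVote (neighbours: (int * float) list) =
    let epsilon = 1e-9
    neighbours
    |> List.groupBy fst
    |> List.map (fun (label, group) ->
        label, group |> List.sumBy (fun (_, d) -> 1.0 / (d + epsilon)))
    |> List.maxBy snd
    |> fst

// Plain majority vote over the same neighbours, for comparison.
let majorityVote (neighbours: (int * float) list) =
    neighbours |> List.countBy fst |> List.maxBy snd |> fst

// Hypothetical (label, distance) pairs: one close 7 and two more distant 1s.
let neighbours = [ (7, 1.0); (1, 3.0); (1, 3.0) ]

printfn "%d" (majorityVote neighbours) // 1: two votes beat one
printfn "%d" (weightedVote neighbours) // 7: a score of about 1.0 beats 2/3
```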

So there we have it, digit recognition with KNN. I hope you found it interesting. May your machines learn well!