Make it run, make it right: The three implementation strategies of TDD

I’ve been doing TDD for a while now, as well as running TDD-based code dojos. I really like it as a method of programming, and use it for the majority of the production code I write.

However, there’s one important area I’ve never got quite straight in my head, around the point at which you go from ‘making it pass quickly’ to ‘implementing it properly’. I’ve always followed the rule (picked up at some past dojo) to make the test pass in the simplest way possible, then refactor if necessary before continuing with the next test. Making it pass simply often involves returning a hard-coded value, but I’ve never fully got a handle on when that value should get replaced with the ‘real’ implementation, and related to that, when certain things should be refactored.

Following some recommendations, I finally read Kent Beck’s Test-Driven Development By Example, an influential book on TDD published back in 2002. Yay!  Finally, the last piece of the puzzle was revealed. Well OK, probably not the last, but an important one…

There are several useful guidelines in the book – for example he specifies the TDD cycle as ‘Write a test, Make it run, Make it right’ (similar to the more commonly used ‘red, green, refactor’ – not sure which came first).  The one I’m focussing on today is his strategy for the implementation of functionality, i.e. the ‘make it run’ and ‘make it right’ phases.

To take the descriptions from the book, the three options you have are:

Fake It – Return a constant and gradually replace constants with variables until you have real code

Obvious Implementation – Type in the real implementation

Triangulation – Only generalize code when we have two examples or more

This is already getting interesting – we now have three well-defined methods of proceeding.  Looking back, I think part of my problem has been trying to stick to one rule in the many and varied situations that arise when coding a solution with TDD.

Let’s have a look at these in a bit more detail, with some example code.  The examples are based on the Checkout kata, which involves implementing a supermarket checkout system which calculates the price of items bought. I haven’t included every step, but hopefully you can fill in the blanks.

Fake It

One of the guiding principles behind these strategies is that you want to get a passing test as quickly as possible.  If that means writing some dodgy code, then so be it – it’s not a concern, because one of the other principles is that the code should be constantly and aggressively refactored.  So get the test passing, then you have a safety net of passing tests which enable you to refactor to your heart’s content.

Often, the quickest way to get a test passing is to return a hard-coded value:

[Test]
public void scanning_single_item_gives_single_item_price() {
    var till = new Till();
    till.Scan("Apple");
    Assert.That(till.Total, Is.EqualTo(10));
}

public class Till {
    public void Scan(string product) {}

    public int Total {
        get {
            return 10;
        }
    }
}

Once that’s done, we can immediately get to refactoring. Kent’s main goal of refactoring is to remove duplication anywhere you can spot it, which he believes generally leads to good design. So where’s the duplication here? All we’re doing is returning a constant. But look closely and you’ll see it’s the same hard-coded constant defined in the tests and the code. Duplication.

So how do we get rid of it? One way would be to have the checkout take the value from the test:

[Test]
public void scanning_single_item_gives_single_item_price() {
    var applePrice = 10;
    var till = new Till(applePrice);
    till.Scan("Apple");
    Assert.That(till.Total, Is.EqualTo(applePrice));
}

public class Till {
    private int _applePrice;
    public Till(int applePrice) {
        _applePrice = applePrice;
    }

    public void Scan(string product) { }

    public int Total {
        get {
            return _applePrice;
        }
    }
}

Well, not how I’d usually implement it, but that makes sense I guess – the checkout would be getting its prices from somewhere external, and passing them in like this helps keep the checkout code nice and self-contained, and adhering to the open/closed principle.  Maybe removing duplication really can help with a good design…

Of course, the checkout would need all prices passing in, not just the price for one product.  But as usual with TDD, we’re taking small steps, and only implementing what we need to based on the current set of tests.

Obvious Implementation

One of the ‘rules’ of TDD which I make people follow in dojos has been to always start with the simplest possible implementation.  This is a concept which mightily annoys some experienced devs when they first try TDD – ‘Why return a hard-coded value’, they ask, ‘when I know what the implementation’s going to be?’.

Well, they will be most pleased with the second option, which states that if the implementation is obvious and quick – just type it in!

So if for our next test we end up with

        const int applePrice = 10;
        const int bananaPrice = 15;
        Till till;

        [SetUp]
        public void setup() {
            till = new Till(new Dictionary<string, int> {
                {"Apple", applePrice},
                {"Banana", bananaPrice}
            });
        }

        [Test]
        public void scanning_different_item_gives_correct_price() {
            till.Scan("Banana");

            Assert.That(till.Total, Is.EqualTo(bananaPrice));
        }

If we’re feeling plucky, we might just go for it:

    public class Till {
        private Dictionary<string, int> _prices;
        private string _scannedItem;

        public Till(Dictionary<string, int> prices) {
            _prices = prices;
        }

        public void Scan(string product) {
            _scannedItem = product;
        }

        public int Total {
            get {
                return _prices[_scannedItem];
            }
        }
    }

However, if you have any doubts, get a failing test, or it’s just taking too long – you’ve taken the wrong option. Back out your changes, and go back to Fake It.

Triangulation

The last option is the least favoured by Mr Beck, reserved for those times when he’s ‘completely unsure of how to refactor’, and ‘the design thoughts just aren’t coming’.  For me, going by the rules I’ve learned in past dojos, it’s one I quite often end up using.  Essentially it involves using Fake It to get a test passing, but instead of going straight to the real implementation, entering another test to force you into writing a generalised solution.

Now, we’re starting to scan multiple items:

        [Test]
        public void scanning_multiple_items_gives_sum_of_prices() {
            till.Scan("Apple");
            till.Scan("Banana");

            Assert.That(till.Total, Is.EqualTo(applePrice + bananaPrice));
        }

We have to update the code a little, but can fake the important bit:

    public class Till {
        private Dictionary<string, int> _prices;
        private List<string> _scannedItems = new List<string>();

        public Till(Dictionary<string, int> prices) {
            _prices = prices;
        }

        public void Scan(string product) {
            _scannedItems.Add(product);
        }

        public int Total {
            get {
                if (_scannedItems.Count == 2) {
                    return 25;
                }
                return _prices[_scannedItems[0]];
            }
        }
    }

Obviously, checking the number of scanned items isn’t going to work for long, but instead of refactoring immediately we can write another test which will force us to change it (this one assumes a Coconut price has been added to the test setup alongside the others).

        [Test]
        public void scanning_different_multiple_items_gives_sum_of_prices() {
            till.Scan("Banana");
            till.Scan("Coconut");

            Assert.That(till.Total, Is.EqualTo(bananaPrice + coconutPrice));
        }

And now we enter the general solution:

        public int Total {
            get {
                return _scannedItems.Sum(item => _prices[item]);
            }
        }

Since I started using these strategies, I have also found that Triangulation isn’t needed much: if you stick strongly to the ‘refactor away duplication’ rule, a single Fake It with immediate refactoring usually does the job.

Conclusion

It’s fair to say that I haven’t mastered TDD yet, even after so long, but I’m pretty happy with it now, thanks to guidance like this. I still find situations where I’m unsure how to proceed, but they are getting rarer. Trying to follow these three strategies has certainly helped, and the ‘refactor to remove duplication’ concept by itself has been very useful. I hope you get some benefit from them too.

Categories: C#, TDD, Unit Testing

Digit recognition in F# with k-nearest neighbours

May 29, 2014 1 comment

In this post, I’ll step through using a simple Machine Learning algorithm called k-nearest neighbours (KNN) to perform handwritten digit recognition.  This is another of the Hello Worlds of ML – it’s how I got introduced to it (via Mathias Brandewinder’s Digit Recogniser Dojo), and also one I’ve run as a dojo a few times myself. It’s based on the Kaggle learning competition of the same name.

KNN is a lazy classification algorithm, which is used to determine what an unknown example of something is based on its similarity to known examples.  Specifically, it finds the k most similar (nearest) examples, and classifies the unknown example according to what they are.

Here’s a visualisation, in which we’re trying to work out what the unknown green thing is (blue square or red triangle?)

KNN visualisation

 

Based on the three nearest examples, we would classify the unknown item as a red triangle.

(Interestingly, if we’d chosen a k of 5, we would say it’s a blue square – we’ll discuss how to approach that later).
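To make the effect of k concrete, here is a small F# sketch of the voting step on made-up 2D points. The coordinates, the Shape type and the classify function are all hypothetical (they are not the points from the picture above), chosen so that k = 3 and k = 5 disagree in the same way.

```fsharp
// Hypothetical toy data: not the points from the visualisation.
type Shape = Square | Triangle

let known =
    [ ((1.0, 0.0), Triangle);
      ((0.0, 1.2), Triangle);
      ((-1.5, 0.0), Square);
      ((0.0, -2.0), Square);
      ((2.5, 0.0), Square);
      ((5.0, 5.0), Triangle) ]

// Squared straight-line distance between two 2D points.
let squaredDistance (x1: float, y1: float) (x2: float, y2: float) =
    (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)

// Take the k nearest known points and return the most common shape among them.
let classify k unknown =
    known
    |> List.sortBy (fun (point, _) -> squaredDistance unknown point)
    |> List.take k
    |> List.countBy snd
    |> List.maxBy snd
    |> fst

printfn "k = 3: %A" (classify 3 (0.0, 0.0))
printfn "k = 5: %A" (classify 5 (0.0, 0.0))
```

With these points, the three nearest neighbours of the origin are two triangles and one square, so k = 3 votes Triangle; widening to k = 5 brings in two more squares and the vote flips to Square.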

The dataset we’ll be running this on is (a subset of) the MNIST handwritten digits, which consists of several thousand 28×28 pixel greyscale images that look a bit like this:

MNIST handwritten digits

So how do you determine the ‘distance’ between two such images?  A common way is to use Euclidean distance, which when applied in this situation involves comparing each pixel and summing the differences between them – or, to be more precise, summing the squares of the differences.

Euclidean Distance

(The actual equation, which you can see above, uses the square root of the sum of the squares – but the square root part doesn’t make any difference in this case). So if two images were identical, the distance would be 0.  If two of the pixels were different, by 20 and 50 respectively, the distance would be 20^2 + 50^2 = 400 + 2500 = 2900 (still very similar).
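Written out, using p and q for the two images and n = 784 for the number of pixels (notation assumed here, not taken from the figure), the standard Euclidean distance and the worked example above look like this:

```latex
d(p, q) = \sqrt{\sum_{i=1}^{n} (p_i - q_i)^2}
\qquad
20^2 + 50^2 = 2900, \quad \sqrt{2900} \approx 53.85
```

Because the square root is monotonic, ordering neighbours by the squared sum gives the same ranking as ordering by the true distance, which is why it can be skipped.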

OK, lesson over, let’s get implementing!  I’ll be doing this in F#, but it’s pretty similar in C# if you use a lot of LINQ.

I’ve put the source code on github.  (You can find a basic version of it in C#, as well as several other languages, in my Digit Recogniser Dojo repository).  I’m using arrays for most of this, which aren’t the most natural thing to use from F#, but they give better performance – which can make a fair difference when working with big datasets.

I’m going to be using 5000 digits, which is enough to get pretty good results without taking too long.  They come formatted as a CSV, with each record having a label (i.e. which digit it represents) followed by each of the 784 pixels (28×28) represented by a number from 0-255:

digits csv

First of all, let’s read the records, and get rid of the header:

open System.IO

let dataLines =
    File.ReadAllLines(__SOURCE_DIRECTORY__ + """\trainingsample.csv""").[1..]

We can check our progress using F# Interactive, and see that we now have an array of strings containing a record each:

val dataLines : string [] =
  [|"1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"+[1669 chars];
    "0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"+[1925 chars];

We can then split the lines up and parse the numbers:

let dataNumbers =
    dataLines
    |> Array.map (fun line -> line.Split(','))
    |> Array.map (Array.map (int))

Now we have an array of arrays of integers:

val dataNumbers : int [] [] =
  [|[|1; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;
      0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;
      0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;
      0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;
      0; 0; 0; 0; ...|];
    [|0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;

Working with arrays of arrays doesn’t make things too clear, so let’s create a type to store the records in:

type DigitRecord = { Label:int; Pixels:int[] }

let dataRecords =
    dataNumbers
    |> Array.map (fun record -> {Label = record.[0]; Pixels = record.[1..]})

So now we have an array of DigitRecords:

val dataRecords : DigitRecord [] =
  [|{Label = 1;
     Pixels =
      [|0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;
        0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;
        0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;
        0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0; 0;
        0; 0; 0; 0; ...|];};

Thinking ahead a bit, we’re going to need some way of testing our algorithm to see how good it is (and also, as described shortly, to work out what value to use for ‘k’). The typical approach for this is to split the known dataset into three – one to train the algorithm, one to choose what options to use, and another to test the accuracy of the final system:

let trainingSet = dataRecords.[..3999]
let crossValidationSet = dataRecords.[4000..4499]
let testSet = dataRecords.[4500..]

We’re nearly at the fun part, so we’ll soon need a way of measuring the distance between two digits, as described earlier. F#’s map2 function comes in very handy for this, applying a function to two same-sized arrays at the same time:

let distanceTo (unknownDigit:int[]) (knownDigit:DigitRecord) =
    Array.map2 (
        fun unknown known ->
            let difference = unknown-known
            int64 (difference * difference)
        ) unknownDigit knownDigit.Pixels
    |> Array.sum

For the unknown parameter we’ll take the raw pixels, as if we ever come to use this on real data we won’t have a label for it (so it won’t be in a DigitRecord).

Let’s give it a test run:

> {Label=1;Pixels=[|120;150|]} |> distanceTo [|100;100|];;
val it : int64 = 2900L

Looks about right!

Now we get to the real meat of it, classifying an unknown digit based on the k nearest known examples. In this function, we compare the digit against the whole training set, take the k closest, and count which label occurs the most times:

let classifyByNearest k (unknownDigit:int[]) =
    trainingSet
    |> Array.sortBy (distanceTo unknownDigit)
    |> Seq.take k
    |> Seq.countBy (fun digit -> digit.Label )
    |> Seq.maxBy (fun (label,count) -> count)
    |> fun (label,count) -> label

F# is particularly good at this kind of code, although you can do something not too much more complicated in C# with a good helping of LINQ (which is basically a functional library).

And with that, we have everything we need to start making predictions! Let’s give it a blast. We need to choose a value for ‘k’ – let’s just use 1 for now:

testSet.[..4]
|> Array.iter (fun digit ->
    printfn "Actual: %d, Predicted: %d"
        digit.Label
        (digit.Pixels |> classifyByNearest 1))
Actual: 5, Predicted: 6
Actual: 3, Predicted: 3
Actual: 9, Predicted: 9
Actual: 8, Predicted: 8
Actual: 7, Predicted: 7

Four out of five, not too shabby! Let’s write a function to find out what the accuracy is for a whole dataset.  We’re taking the dataset to calculate the accuracy on as a parameter because we’ll use it in two different ways, which you’ll see shortly.

let calculateAccuracyWithNearest k dataSet =
    dataSet
    |> Array.averageBy (fun digit ->
        if digit.Pixels |> classifyByNearest k = digit.Label then 1.0
        else 0.0)

So, let’s give it a try!  We still haven’t yet worked out what value to use for k, so let’s stick with 1 for now.

> testSet |> calculateAccuracyWithNearest 1;;
val it : float = 0.93

93%, not bad for 30-odd lines of code!

But what about k, how do we decide what value to use?  This is where the cross-validation set comes in.  The idea is that you use the cross-validation set to test the various options you have (in this case, different values of k) and find the best one(s).  You can then use the test set to get a final figure for the accuracy of your algorithm.  The reason we have separate validation and test sets, instead of just using the same test set for both, is to help get a final figure which is more representative of what our algorithm will get when used on real data.

If we just took the accuracy as measured by the validation set when finding the best options, we’d have an artificially high figure, as we’ve specifically chosen the options to maximise that number.  The chances are, when we run it on real data, the accuracy will be lower.

To begin with, we can try a few values for k and plot the result (with FSharp.Charting), to see if we’re in the right ballpark:

let predictionAccuracy =
    [1;3;9;27]
    |> List.map (fun k ->
        (k, crossValidationSet |> calculateAccuracyWithNearest k))

Chart.Line(predictionAccuracy)
val predictionAccuracy : (int * float) list =
  [(1, 0.93); (3, 0.936); (9, 0.94); (27, 0.91)]

accuracy chart

As you can hopefully see, the accuracy seems to peak somewhere around 10 then starts to drop off. So let’s try all of the values in this range to find the best one (this may take a while – time to get a coffee!):

let bestK =
    [1..20]
    |> List.maxBy (fun k ->
        crossValidationSet |> calculateAccuracyWithNearest k)

Turns out the best value for k, at least according to the validation set, is 6. On the validation set, that gets us a heady 94% accuracy:

> crossValidationSet |> calculateAccuracyWithNearest bestK;;
val it : float = 0.94

As previously described, that’s likely to be a little optimistic, so let’s get a final measure using the test set:

> testSet |> calculateAccuracyWithNearest bestK;;
val it : float = 0.926

Final result: 92.6%!

Not a bad result for a simple algorithm. In fact in this case, we actually get a better result on the test set using a k of 1. With so little in it, we might decide to just use a ‘1-nearest neighbours’ algorithm, which would be simpler still.

Like many Machine Learning systems, this also gets better the more data you use – I submitted code based on this to the Kaggle contest, which has a dataset of 50,000 digits, and got an accuracy of around 97%. More data would also help with the disparity between the validation and test sets, and mean the options we chose using the validation set would be more likely to be optimum for the test set as well.

KNN can be used in several situations – you just need some way of measuring similarity between two things.  For example, you could use Levenshtein distance to compare strings and find the best match for somebody’s name, or physical distance to find the nearest post office for a given home.
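As a rough sketch of the name-matching idea, here is a plain dynamic-programming Levenshtein distance in F# used as the distance function for a 1-nearest-neighbour lookup. The function names (levenshtein, closestName) and the list of names are made up for illustration.

```fsharp
// Number of single-character insertions, deletions and substitutions
// needed to turn string a into string b.
let levenshtein (a: string) (b: string) =
    // previous.[j] is the distance between the first (i - 1) chars of a
    // and the first j chars of b; it starts as the row for the empty prefix.
    let mutable previous = Array.init (b.Length + 1) id
    for i in 1 .. a.Length do
        let current : int[] = Array.zeroCreate (b.Length + 1)
        current.[0] <- i
        for j in 1 .. b.Length do
            let substitutionCost = if a.[i - 1] = b.[j - 1] then 0 else 1
            current.[j] <-
                min (min (previous.[j] + 1) (current.[j - 1] + 1))
                    (previous.[j - 1] + substitutionCost)
        previous <- current
    previous.[b.Length]

// 1-nearest neighbour: pick the candidate with the smallest edit distance.
let closestName (candidates: string list) (name: string) =
    candidates |> List.minBy (levenshtein name)

let best = closestName [ "John"; "Jane"; "Joanna" ] "Jon"
printfn "Closest match: %s" best
```

Here "Jon" is one edit away from "John", two from "Jane" and three from "Joanna", so the lookup returns "John".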

An interesting alternative is the weighted nearest neighbour algorithm, which potentially measures the distance to all neighbours but weighs their influence based on their distance (e.g. close ones count for a lot when taking the vote, distant ones hardly at all).
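A minimal sketch of that idea in F#, assuming a weight of 1 / (1 + distance) for each vote. The weighting formula, the classifyByWeightedVote name and the tiny two-pixel training set are my own illustration, not code from the repository; other weightings are equally possible.

```fsharp
type DigitRecord = { Label: int; Pixels: int[] }

// Squared Euclidean distance, the same measure used earlier in the post.
let distanceTo (unknownDigit: int[]) (knownDigit: DigitRecord) =
    Array.map2 (fun unknown known ->
        let difference = unknown - known
        int64 (difference * difference)) unknownDigit knownDigit.Pixels
    |> Array.sum

// Every known example votes for its label with weight 1 / (1 + distance),
// so close examples dominate and distant ones barely count.
let classifyByWeightedVote (trainingSet: DigitRecord[]) (unknownDigit: int[]) =
    trainingSet
    |> Array.map (fun known ->
        known.Label, 1.0 / (1.0 + float (distanceTo unknownDigit known)))
    |> Array.groupBy fst
    |> Array.map (fun (label, votes) -> label, votes |> Array.sumBy snd)
    |> Array.maxBy snd
    |> fst

// Hypothetical two-pixel "digits": one close 0 and two distant 1s.
let tinyTrainingSet =
    [| { Label = 0; Pixels = [| 0; 0 |] }
       { Label = 1; Pixels = [| 200; 200 |] }
       { Label = 1; Pixels = [| 210; 190 |] } |]

let predicted = classifyByWeightedVote tinyTrainingSet [| 10; 5 |]
printfn "Predicted: %d" predicted
```

An unweighted vote over all three examples would pick 1 (two votes to one), but the single nearby 0 carries far more weight than the two distant 1s combined, so the weighted vote predicts 0.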

So there we have it, digit recognition with KNN.  I hope you found it interesting.  May your machines learn well!

Categories: F#, Machine Learning

Web requests in F# now easy! Introducing Http.fs

November 15, 2013 3 comments

TL;DR

I’ve made a module which makes HTTP calls (like downloading a web page) easy, available now on GitHub – Http.fs

Introduction

I had a project recently which involved making a lot of HTTP requests and dealing with the responses.  F# being my current language of choice, I was using that.  Unfortunately, .Net’s HttpWebRequest/Response aren’t that nice to use from F# (or C#, frankly).

For example, here’s how you might make an HTTP Post, from F# Snippets:

open System.Text
open System.IO
open System.Net

let url = "http://posttestserver.com/post.php"

let req = HttpWebRequest.Create(url) :?> HttpWebRequest
req.ProtocolVersion <- HttpVersion.Version10
req.Method <- "POST"

let postBytes = Encoding.ASCII.GetBytes("fname=Tomas&lname=Petricek")
req.ContentType <- "application/x-www-form-urlencoded";
req.ContentLength <- int64 postBytes.Length
let reqStream = req.GetRequestStream()
reqStream.Write(postBytes, 0, postBytes.Length);
reqStream.Close()

let resp = req.GetResponse()
let stream = resp.GetResponseStream()
let reader = new StreamReader(stream)
let html = reader.ReadToEnd()

There are a few things I don’t really like about doing this:

  • It’s a lot of code!
  • You have to mess around with streams
  • The types used are mutable, so not really idiomatic F#
  • You have to set things (e.g. ‘POST’) as strings, so not typesafe
  • It’s not unit testable
  • You have to cast things (e.g. req) to the correct type

In fact there are many other problems with HttpWebRequest/Response which aren’t demonstrated by this sample, including:

  • Some headers are defined, others aren’t (so you have to set them as strings)
  • You need to mess around with the cookie container to get cookies working
  • If the response code is anything but 200-level, you get an exception (!)
  • Getting headers and cookies from the response isn’t pretty

Since then I’ve discovered HttpClient, which does address some of these issues, but it’s still not great to use from F# (and only available in .Net 4.5).

So I started to write some wrapper functions around this stuff, and it eventually turned into:

Http.fs!

Http.fs is a module which contains a few types and functions for more easily working with Http requests and responses from F#. It uses HttpWebRequest/Response under the hood, although these aren’t exposed directly when you use it.

Downloading a single web page is as simple as:

let page = (createRequest Get "http://www.google.com" |> getResponseBody)

And if you want to do something more in-depth, like the example above, that would look like this:

open HttpClient

let response =
  createRequest Post "http://posttestserver.com/post.php"
  |> withBody "fname=Tomas&lname=Petricek"
  |> withHeader (ContentType "application/x-www-form-urlencoded")
  |> getResponse

Then you could access the response elements like so:

response.StatusCode
response.EntityBody.Value
response.Headers.[Server]

And of course, it has asynchronous functions to let you do things like download multiple pages in parallel:

["http://news.bbc.co.uk"
 "http://www.wikipedia.com"
 "http://www.stackoverflow.com"]
|> List.map (fun url -> createRequest Get url |> getResponseBodyAsync)
|> Async.Parallel
|> Async.RunSynchronously
|> Array.iter (printfn "%s")

There are more details on the GitHub page. The project also contains a sample application which shows how it can be used and tested.

So if you’re using F# and want to make a complex HTTP request – or just download a web page – check out Http.fs!

Update

This is now available on NuGet.  To install:

PM> install-package Http.fs  
Categories: F#, HTTP

Hello Neurons – ENCOG Neural Network XOR example in F#

November 14, 2013 1 comment

I’ve been playing with Machine Learning lately, starting with Abhishek Kumar’s Introduction to Machine Learning video on PluralSight.

This video guides you through using the ENCOG library (available on NuGet) to build a simple neural network for the XOR (eXclusive OR) logic table, which is the ‘Hello World’ of Neural Networks.

I’m not going to go into the details of ML or Neural Networks here (I don’t know them, for a start), but one thing I found was that the .Net ENCOG examples were all in C#.  As such, I thought I’d post my F# version here. (See the C# version for comparison).

So, without further ado:

open Encog.ML.Data.Basic
open Encog.Engine.Network.Activation
open Encog.Neural.Networks
open Encog.Neural.Networks.Layers
open Encog.Neural.Networks.Training.Propagation.Resilient

let createNetwork() =
    let network = BasicNetwork()
    network.AddLayer( BasicLayer( null, true, 2 ))
    network.AddLayer( BasicLayer( ActivationSigmoid(), true, 2 ))
    network.AddLayer( BasicLayer( ActivationSigmoid(), false, 1 ))
    network.Structure.FinalizeStructure()
    network.Reset()
    network

let train trainingSet (network: BasicNetwork) =
    let trainedNetwork = network.Clone() :?> BasicNetwork
    let trainer = ResilientPropagation(trainedNetwork, trainingSet)

    let rec trainIteration epoch error =
        match error > 0.001 with
        | false -> ()
        | true -> trainer.Iteration()
                  printfn "Iteration no : %d, Error: %f" epoch error
                  trainIteration (epoch + 1) trainer.Error

    trainIteration 1 1.0
    trainedNetwork

[<EntryPoint>]
let main argv =

    let xor_input =
        [|
            [| 0.0 ; 0.0 |]
            [| 1.0 ; 0.0 |]
            [| 0.0 ; 1.0 |]
            [| 1.0 ; 1.0 |]
        |]

    let xor_ideal =
        [|
            [| 0.0 |]
            [| 1.0 |]
            [| 1.0 |]
            [| 0.0 |]
        |]

    let trainingSet = BasicMLDataSet(xor_input, xor_ideal)
    let network = createNetwork()

    let trainedNetwork = network |> train trainingSet

    trainingSet
    |> Seq.iter (
        fun item ->
            let output = trainedNetwork.Compute(item.Input)
            printfn "Input: %f, %f Ideal: %f Actual: %f"
                item.Input.[0]  item.Input.[1] item.Ideal.[0] output.[0])

    printfn "Press return to exit.."
    System.Console.Read() |> ignore

    0 // return an integer exit code

The main difference over the C# version is that the training iterations are done with recursion instead of looping, and the training returns a new network rather than updating the existing one. Nothing wrong with doing it that way per se, but it gave me a warm feeling inside to make it all ‘functional’.

It may be a while before I create Skynet, but you’ve got to start somewhere..

Categories: F#, Machine Learning

Testing DateTime.Now (and other things) with an Ambient Context

September 30, 2013 Leave a comment

One issue you’re likely to face while writing automated tests is that it can be tricky to check anything related to DateTimes, as they are typically accessed statically using:

var startTime = DateTime.Now;

The approach I’ve typically used is to create some kind of IDateProvider interface, and supply that as a dependency as needed using Dependency Injection.  However, with things used throughout a system, as DateTime can be, this becomes quite a chore and makes the code more complex with an ever-increasing list of dependencies.

One of the great tips given by Ian Russell at an Agile Yorkshire meetup recently was to use the Ambient Context pattern in this, and other similar situations.

Note: I didn’t write the code below, I found it in various places online, particularly this Stack Overflow question by Mark Seemann, author of Dependency Injection in .Net.  I was having trouble finding the details online, which is why I’ve posted them here.

So, to take full control of DateTime.Now, you need an abstract TimeProvider:

public abstract class TimeProvider {
   private static TimeProvider current =
   DefaultTimeProvider.Instance;

   public static TimeProvider Current {
      get { return TimeProvider.current; }
      set {
         if (value == null) {
            throw new ArgumentNullException("value");
         }
         TimeProvider.current = value;
      }
   }

   public abstract DateTime UtcNow { get; }

   public static void ResetToDefault() {
      TimeProvider.current = DefaultTimeProvider.Instance;
   }
}

A concrete implementation used in your live system:

public class DefaultTimeProvider : TimeProvider {
   private readonly static DefaultTimeProvider instance = 
      new DefaultTimeProvider();

   private DefaultTimeProvider() { }

   public override DateTime UtcNow {
      get { return DateTime.UtcNow; }
   }

   public static DefaultTimeProvider Instance {
      get { return DefaultTimeProvider.instance; }
   }
}

And a mock for testing, for example here set up with Rhino.Mocks:

var fakeTimeProvider = MockRepository.GenerateStub<TimeProvider>();
fakeTimeProvider.Expect(x => x.UtcNow).Repeat.Once()
   .Return(new DateTime(1999, 12, 31, 11, 59, 59, 999));
fakeTimeProvider.Expect(x => x.UtcNow).Repeat.Once()
   .Return(new DateTime(2000, 1, 1, 0, 0, 0, 0));
// Install the stub as the ambient time provider
TimeProvider.Current = fakeTimeProvider;

For this to work, instead of calling DateTime.Now (or DateTime.UtcNow as we’re using here) from your code, you need to call:

var startTime = TimeProvider.Current.UtcNow;

One last thing – don’t forget to reset the provider after you mock it out, otherwise your mock will remain as the time provider:

[TestCleanup]
public void TearDown() {
   TimeProvider.ResetToDefault();
}

As suggested in the title, this can be used for things other than DateTime, but that’s the only place I’ve tried it so far.

Happy testing!

Categories: Patterns, TDD, Unit Testing