Writing Robust Concurrency Tests In Go Using a CountDownLatch

Recently, while writing tests involving concurrency in Go, I encountered the common situation of needing to wait for a series of events across multiple goroutines to complete before proceeding with further code execution.

For instance, suppose you’re testing a server using some mock clients. You start up the clients, wait until they’ve completed their work, and then check the outcomes of the client operations. One way to ‘wait’ for the clients to complete is to simply sleep in the main goroutine, as follows:

    var clients []*MyMockClient
    for i := 0; i < 10; i++ {
        // fire up each client
        client := NewMockClient("myServerUri")
        clients = append(clients, client)

        go func() {
            // do work
            // ...
            client.DoWork()
        }()
    }

    // wait 10 seconds for clients to complete work
    time.Sleep(10*time.Second)

    // now check client results
    for _, client := range clients {
        err := client.CheckResults()
        if err != nil {
            // fail the test
        }
    }

One of the downsides of using time.Sleep like this is that it’s hard to know in advance how long to sleep. If the sleep interval is too short, you may start checking the clients’ results before they’re ready to be evaluated, and the test fails intermittently. Sleep too long and you end up with a long-running test that takes more time than it should, which is an annoyance when you run the same test set repeatedly.

In Go, the sync.WaitGroup concurrency primitive can be used to wait for a series of events to complete. Here’s how it works:

    var clients []*MyMockClient
    var wg sync.WaitGroup

    for i := 0; i < 10; i++ {
        // fire up each client
        client := NewMockClient("myServerUri")
        clients = append(clients, client)

        // add 1 to the WaitGroup counter
        wg.Add(1)
        go func() {

            // decrement the WaitGroup counter when the function completes
            defer wg.Done()
            // do work
            // ...
            client.DoWork()

        }()
    }

    // wait for all clients to finish (i.e. WaitGroup counter hits 0)
    wg.Wait()

    // now check client results
    for _, client := range clients {
        err := client.CheckResults()
        if err != nil {
            // fail the test
        }
    }

Essentially, all the goroutines share a single WaitGroup. The main goroutine calls Add(1) before starting each child, each child goroutine signals its own completion by calling Done, and the main goroutine’s call to Wait blocks until the counter drops back to zero, i.e. until every child goroutine has finished.
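
To see the Add/Done/Wait handshake in isolation, here is a minimal runnable sketch. It swaps the article’s MyMockClient for a hypothetical runWorkers function whose goroutines only increment a counter (using sync/atomic’s Int64 type, which needs Go 1.19 or later), so it illustrates the pattern rather than any real client code.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// runWorkers is a hypothetical stand-in for starting n mock clients.
// It starts n goroutines, blocks until every one of them has called Done,
// and returns how many goroutines finished their "work".
func runWorkers(n int) int64 {
	var wg sync.WaitGroup
	var finished atomic.Int64

	for i := 0; i < n; i++ {
		wg.Add(1) // register the goroutine before it starts
		go func() {
			defer wg.Done() // decrement the counter when this goroutine returns
			finished.Add(1) // placeholder for client.DoWork()
		}()
	}

	wg.Wait() // blocks until the counter is back to zero
	return finished.Load()
}

func main() {
	fmt.Println(runWorkers(10)) // prints 10
}
```

Because Wait only returns once every Done has happened, the count read afterwards always equals the number of goroutines started.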

This works well in general and avoids the problems of a fixed sleep. But it assumes that every client will actually complete its work. What if something goes wrong and one of the clients blocks and never calls WaitGroup.Done()? That can easily happen in a test scenario, and when it does, wg.Wait() never returns: the test hangs instead of failing, until go test’s own timeout (10 minutes by default) aborts the whole run. A test has to finish before it can pass.

Alternatively, suppose you’re testing for performance, and you expect the clients to all finish within a certain window of time.

In either case, what you need is the ability to Wait on a WaitGroup with a timeout. The timeout you provide is a worst-case bound: if the Wait doesn’t complete before the timeout expires, you can treat the test as failed.

sync.WaitGroup doesn’t natively provide a Wait with a timeout. This Stack Overflow thread suggests working around that with a helper function, something like this:

    // WaitTimeout returns true if wg.Wait() completes before the timeout,
    // false otherwise.
    func WaitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
        ch := make(chan struct{})

        go func() {
            defer close(ch)
            wg.Wait()
        }()

        select {
        case <-ch:
            return true
        case <-time.After(timeout):
            return false
        }
    }

This function works, but the code doesn’t smell right. If wg.Wait() never returns, the goroutine that called it stays blocked forever, so every call that times out leaks a goroutine. Moreover, waiting with a timeout shouldn’t require a separate helper function; it should be a first-class method of the WaitGroup object.

Using A CountDownLatch

To work around sync.WaitGroup’s lack of a timed wait, I created another type called CountDownLatch, which is available here. It’s modeled after the CountDownLatch class in Java’s java.util.concurrent package. With it, the code above looks as follows:

    var clients []*MyMockClient
    numClients := 10
    latch := congo.NewCountDownLatch(uint(numClients))

    for i := 0; i < numClients; i++ {
        // fire up each client
        client := NewMockClient("myServerUri")
        clients = append(clients, client)

        go func() {

            // count down on the latch when the function is complete
            defer latch.CountDown()
            // do work
            // ...
            client.DoWork()

        }()
    }

    // wait up to 5 seconds for all clients to finish (i.e. latch counter hits 0)
    finished := latch.WaitTimeout(5*time.Second)
    if !finished {
        // fail the test
    } else {
        // now check client results
        for _, client := range clients {
            err := client.CheckResults()
            if err != nil {
                // fail the test
            }
        }
    }

CountDownLatch accomplishes the same goal as WaitGroup, but now you can bail out after a set timeout if needed. The above code gracefully marks the test as failed if the clients don’t complete their work within the 5-second window.

Tracking Latch Progress with Count()

CountDownLatch comes with a few additional perks beyond waiting with a timeout. One of these is a Count method that returns the remaining count. This is useful in performance-monitoring scenarios where you want to report how far the countdown has progressed, and possibly take action programmatically to speed it up or slow it down. The following code prints the remaining count every 500 milliseconds while it waits:

    var clients []*MyMockClient
    numClients := 10
    latch := congo.NewCountDownLatch(uint(numClients))

    for i := 0; i < numClients; i++ {
        // fire up each client
        client := NewMockClient("myServerUri")
        clients = append(clients, client)

        go func() {

            // count down on the latch when the function is complete
            defer latch.CountDown()
            // do work
            // ...
            client.DoWork()

        }()
    }

    // wait in 500 ms slices (at most 10), printing the remaining count
    // each time a slice expires before the latch has opened
    for i := 0; i < 10 && !latch.WaitTimeout(500*time.Millisecond); i++ {
        fmt.Println("Latch remaining count:", latch.Count())
    }

Conclusion

CountDownLatch overcomes some of the limitations of sync.WaitGroup to help with robust testing, and opens up other uses such as performance monitoring.

For more details, you can refer to the project here or its full documentation here.