Improving Code Confidently with Test-Driven Development

Last updated October 21st, 2021

Like many developers, when I was first introduced to Test-Driven Development (TDD), I didn't understand it at all. I had no clue (nor the patience) how to go about writing tests first. So I didn't put much effort into it and stuck to my normal flow of writing code first and adding tests to cover it afterward. This continued for a number of years.

I co-founded typless back in 2017, where I currently lead the engineering efforts. We had to move quickly in the beginning, so we racked up quite a bit of tech debt. At the time, the platform itself was backed by a large, monolithic Django app. Debugging was very difficult. The code was hard to read and even harder to change. We'd squash one bug and three more would replace it. It was time for me to give TDD another shot. Everything I had read about it indicated that it would help, and it did. I finally saw one of the core side effects of TDD: it makes code changes much easier.


Software is a living thing

One of the most important quality factors for any piece of software is how easy it is to change.

"Good design means that when I make a change, it’s as if the entire program was crafted in anticipation of it. I can solve a task with just a few choice function calls that slot in perfectly, leaving not the slightest ripple on the placid surface of the code." (source)

Software changes as business requirements change. Regardless of what drove the change, the solution that worked yesterday may not work today.

It's much easier to change clean, modular code, covered by tests, which is exactly the type of code that TDD tends to produce.

Let's look at an example.

Requirements

Say you have a client that wants you to develop a basic phone book for adding and displaying (in alphabetical order) phone numbers.

Should you create a list of numbers along with a few helper functions for appending, sorting, and printing? Or should you create a class? Sure, it probably doesn't matter so much at this point. You could start writing code to meet the current requirements. It feels like the most natural thing to do, right? However, what if those requirements change and you have to include searching or deleting? The code could quickly become messy if you don't decide on a smart strategy to begin with.

So take a step back and write some tests first.

Writing tests first

First, create (and activate) a virtual environment and install pytest:

(venv)$ pip install pytest

Create a new file called test_phone_book.py to hold your tests.

It seems reasonable to start with a class with two methods, add and all, right?

  • GIVEN a PhoneBook class with a records property
  • WHEN the all method is called
  • THEN all numbers should be returned in ascending order

The test should look something like this:

class TestPhoneBook:

    def test_all(self):

        phone_book = PhoneBook(
            records=[
                ('John Doe', '03 234 567 890'),
                ('Marry Doe', '01 234 567 890'),
                ('Donald Doe', '02 234 567 890'),
            ]
        )

        previous = ''

        for record in phone_book.all():
            assert record[0] > previous
            previous = record[0]

Here, we check that the previous element is always alphabetically less than the current element.

Run it:

(venv)$ pytest

Of course, the test fails, since PhoneBook doesn't exist yet.

To implement it, add a new file called phone_book.py:

class PhoneBook:

    def __init__(self, records=None):
        self.records = records or []

    def all(self):
        return sorted(self.records)

Import it into the test file:

from phone_book import PhoneBook


class TestPhoneBook:

    def test_all(self):

        phone_book = PhoneBook(
            records=[
                ('John Doe', '03 234 567 890'),
                ('Marry Doe', '01 234 567 890'),
                ('Donald Doe', '02 234 567 890'),
            ]
        )

        previous = ''

        for record in phone_book.all():
            assert record[0] > previous
            previous = record[0]

Run it again:

(venv)$ pytest

The test passes now. You've met one of the first requirements.
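
One caveat with the loop in test_all: if all() ever returned an empty list, the loop body would never run and the test would still pass. A slightly stricter sketch checks that no records went missing and compares the returned names against a sorted copy. The PhoneBook class is repeated from phone_book.py only so the snippet runs on its own, and the helper `names_of` is a hypothetical name.

```python
class PhoneBook:
    # Same minimal implementation as in phone_book.py, repeated so this snippet runs on its own
    def __init__(self, records=None):
        self.records = records or []

    def all(self):
        return sorted(self.records)


def names_of(records):
    # Hypothetical helper: extract just the names from (name, number) tuples
    return [name for name, _ in records]


def test_all_returns_every_record_sorted_by_name():
    records = [
        ('John Doe', '03 234 567 890'),
        ('Marry Doe', '01 234 567 890'),
        ('Donald Doe', '02 234 567 890'),
    ]
    phone_book = PhoneBook(records=list(records))

    result = phone_book.all()

    # Fails if records go missing, so an empty result can't pass by accident
    assert len(result) == len(records)
    # Fails if the names are not in ascending (alphabetical) order
    assert names_of(result) == sorted(names_of(result))


test_all_returns_every_record_sorted_by_name()
```

The trade-off is that `sorted(names)` accepts equal neighbouring names, while the original `>` check would reject two records with the same name.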

Now write a test for an add method to check that a new number is in records.

  • GIVEN a PhoneBook with an add method
  • WHEN a number is added and the all method is called
  • THEN the new number is part of the returned numbers

Add the test to test_phone_book.py:

from phone_book import PhoneBook


class TestPhoneBook:

    def test_all(self):

        phone_book = PhoneBook(
            records=[
                ('John Doe', '03 234 567 890'),
                ('Marry Doe', '01 234 567 890'),
                ('Donald Doe', '02 234 567 890'),
            ]
        )

        previous = ''

        for record in phone_book.all():
            assert record[0] > previous
            previous = record[0]

    def test_add(self):

        record = ('John Doe', '01 234 567 890')
        phone_book = PhoneBook(
            records=[
                ('Marry Doe', '01 234 567 890'),
                ('Donald Doe', '02 234 567 890'),
            ]
        )
        phone_book.add(record)

        assert record in phone_book.all()

Run the tests. test_add should fail since the add method isn't implemented yet. Add it to the PhoneBook class:

class PhoneBook:

    def __init__(self, records=None):
        self.records = records or []

    def all(self):
        return sorted(self.records)

    def add(self, record):
        self.records.append(record)

Run the tests again. Both should pass. The PhoneBook class now meets all of the above-mentioned requirements: numbers can be added, and all of them can be returned sorted alphabetically. The customer is delighted. Package and deliver the code.

New requirements

Let's reflect on that first implementation.

Although we used tests to better define what should be done, we could have easily written code without them. In fact, the tests seemed to slow the process down.

A few weeks go by and you don't hear from the client. They must be enjoying adding and viewing phone numbers. Nice job. Give yourself a pat on the back, and send a gentle reminder to the client about that unpaid invoice. Not thirty seconds after you click send, you receive a frustrated email back stating that retrieving numbers is quite slow.

What's going on? Well, you're sorting the records every time the all method is called, so retrieving them gets slower as the phone book grows. So, let's change the code to sort the records on initialization and whenever a new number is added.

Since we focused on testing the interface rather than the underlying implementation, we can change the code without breaking the tests.

class PhoneBook:

    def __init__(self, records=None):
        self.records = sorted(records or [])

    def add(self, record):
        self.records.append(record)
        self.records.sort()

    def all(self):
        return self.records

Tests should still pass.

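That's great, but we can speed things up even more. Since the records are already sorted, there's no need to re-sort the whole list every time a number is added; we only need to insert the new record at the right position.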

class PhoneBook:

    def __init__(self, records=None):
        self.records = sorted(records or [], key=lambda rec: rec[0])

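    def add(self, record):

        # Default to the end of the list (the new name sorts after every existing one)
        index = len(self.records)
        # Find the first record whose name comes after the new record's name
        for i in range(len(self.records)):
            if record[0] < self.records[i][0]:
                index = i
                break

        # Insert in front of that record so the list stays sorted by name
        self.records.insert(index, record)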

    def all(self):
        return self.records

Here, we insert the new record directly at its sorted position instead of sorting the whole list.

Although we've changed the implementation to meet the new requirements, we still meet our initial ones. How do we know? Run the tests.

Can we do better?

We've met all of our requirements. That's great. Our client pays the invoice. All is well. Time passes. You forget about the project. Then, out of the blue, you see an email from them in your inbox complaining that the app is now slow when a new number is added.

You open your text editor and begin to investigate. Having forgotten about the project, you start with the tests and then dive into the code. Looking at the add method, you see that, to preserve the order, it first has to find the exact spot for the new number and then insert it there. Both of these steps, searching for the insertion index and inserting, have O(n) time complexity.
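
The two costs can be separated. Because the list is already sorted, the insertion index can be found with binary search in O(log n). The sketch below does the search by hand (`insertion_index` and `add_sorted` are hypothetical names) and, like the existing loop, places a new record after any records with the same name. The standard library's bisect module offers the same search, and since Python 3.10 `bisect.insort` accepts a `key` argument. Either way, `list.insert` still has to shift every element after the index, so adding a record stays O(n) overall.

```python
def insertion_index(records, name):
    """Binary search over records sorted by name.

    Returns the index just after the last record whose name is <= name,
    so a new record lands after existing records with the same name.
    """
    lo, hi = 0, len(records)
    while lo < hi:
        mid = (lo + hi) // 2
        if name < records[mid][0]:
            hi = mid  # insertion point is at or before mid
        else:
            lo = mid + 1  # insertion point is after mid
    return lo


def add_sorted(records, record):
    # O(log n) search, but list.insert still shifts the later elements: O(n)
    records.insert(insertion_index(records, record[0]), record)


records = []
for record in [
    ('John Doe', '03 234 567 890'),
    ('Marry Doe', '01 234 567 890'),
    ('Donald Doe', '02 234 567 890'),
]:
    add_sorted(records, record)

print([name for name, _ in records])  # ['Donald Doe', 'John Doe', 'Marry Doe']
```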

So, how do you improve performance there?

Turn to Google and Stack Overflow. Use your information retrieval skills. After an hour or so, you find that inserting into a binary search tree takes O(log n) time on average. That's better. Besides that, the elements can be returned in sorted order with an in-order traversal. So go ahead and change your implementation to use a binary search tree instead of a list.
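
The O(log n) figure only holds while the tree stays reasonably balanced. A plain binary search tree has no rebalancing, so inserting records that already arrive in alphabetical order sends every new record to the right and produces what is effectively a linked list. The sketch below makes that visible with a stripped-down node that uses the same left/right comparison on the name as the Node class built in this section; `height`, `build`, and `middle_first` are hypothetical helpers, and the 100 records are made up.

```python
class Node:
    # Stripped-down BST node: same left/right comparison on record[0] as the article's Node
    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        if data[0] < self.data[0]:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.insert(data)
        elif data[0] > self.data[0]:
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.insert(data)


def height(node):
    # Hypothetical helper: number of nodes on the longest root-to-leaf path
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def build(records):
    # Hypothetical helper: insert records one by one, in the given order
    root = Node(records[0])
    for record in records[1:]:
        root.insert(record)
    return root


def middle_first(items):
    # Hypothetical helper: median first, then the medians of each half, recursively
    if not items:
        return []
    mid = len(items) // 2
    return [items[mid]] + middle_first(items[:mid]) + middle_first(items[mid + 1:])


# 100 made-up records whose names are already in alphabetical order
records = [(f'Person {i:03}', '01 234 567 890') for i in range(100)]

print(height(build(records)))                # 100: every insert goes right, forming a chain
print(height(build(middle_first(records))))  # 7: same records, inserted in a balanced order
```

Self-balancing trees such as AVL or red-black trees keep the height close to log2(n) regardless of insertion order, at the cost of extra bookkeeping on every insert.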

Binary tree

New to binary trees? Check out the Binary Trees in Python: Introduction and Traversal Algorithms video as well as the excellent binarytree library.

First, define a node:

class Node:

    def __init__(self, data):

        self.left = None
        self.right = None
        self.data = data

Second, add an insert method:

class Node:

    def __init__(self, data):

        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        # Compare the new value with the parent node
        if self.data:
            if data[0] < self.data[0]:
                if self.left is None:
                    self.left = Node(data)
                else:
                    self.left.insert(data)
            elif data[0] > self.data[0]:
                if self.right is None:
                    self.right = Node(data)
                else:
                    self.right.insert(data)
        else:
            self.data = data

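Here, we first check whether the current node already holds data.

If it doesn't, the new record is stored there.

If it does, we compare the new record's name (its first element) with the name stored in the current node. Smaller names go into the left subtree and larger names into the right subtree, creating a new node when that side is still empty.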
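
As written, insert has no branch for a record whose name equals the current node's name, so a second "John Doe" with a different number is silently dropped, whereas the list-based versions kept both entries. The requirements don't say whether duplicate names are allowed, so treat this as an assumption: if they should be kept, one option is to send equal names to the right subtree, which also makes the in-order traversal return them in insertion order. A sketch of that variant, where only the elif is changed to an else:

```python
class Node:

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        # Compare the new value with the parent node
        if self.data:
            if data[0] < self.data[0]:
                if self.left is None:
                    self.left = Node(data)
                else:
                    self.left.insert(data)
            else:
                # Equal names also go right, so duplicates are kept in insertion order
                if self.right is None:
                    self.right = Node(data)
                else:
                    self.right.insert(data)
        else:
            self.data = data

    def inorder_traversal(self, root):
        res = []
        if root:
            res = self.inorder_traversal(root.left)
            res.append(root.data)
            res = res + self.inorder_traversal(root.right)
        return res


root = Node(('John Doe', '03 234 567 890'))
root.insert(('Donald Doe', '02 234 567 890'))
root.insert(('John Doe', '04 234 567 890'))
print(root.inorder_traversal(root))
# [('Donald Doe', '02 234 567 890'), ('John Doe', '03 234 567 890'), ('John Doe', '04 234 567 890')]
```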

Finally, add the in-order traversal method:

class Node:

    def __init__(self, data):

        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        # Compare the new value with the parent node
        if self.data:
            if data[0] < self.data[0]:
                if self.left is None:
                    self.left = Node(data)
                else:
                    self.left.insert(data)
            elif data[0] > self.data[0]:
                if self.right is None:
                    self.right = Node(data)
                else:
                    self.right.insert(data)
        else:
            self.data = data

    def inorder_traversal(self, root):
        res = []
        if root:
            res = self.inorder_traversal(root.left)
            res.append(root.data)
            res = res + self.inorder_traversal(root.right)
        return res
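
The recursive traversal works, but it builds and concatenates new lists at every level. And because this tree is never rebalanced, adding around a thousand or more records in alphabetical order creates a chain deep enough to exceed CPython's default recursion limit of 1000 frames and raise RecursionError (insert recurses the same way). A sketch of an iterative alternative that keeps an explicit stack instead; `inorder` is a hypothetical standalone generator, and the Node below is a minimal stand-in that accepts optional left and right children so a small tree can be built by hand:

```python
class Node:
    # Minimal stand-in with the same fields as the article's Node
    def __init__(self, data, left=None, right=None):
        self.left = left
        self.right = right
        self.data = data


def inorder(root):
    """Yield node data in sorted (left, node, right) order without recursion."""
    stack = []
    node = root
    while stack or node is not None:
        # Walk as far left as possible, remembering the path on the stack
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        yield node.data
        # Then continue with the right subtree of the node just yielded
        node = node.right


root = Node(
    ('John Doe', '03 234 567 890'),
    left=Node(('Donald Doe', '02 234 567 890')),
    right=Node(('Marry Doe', '01 234 567 890')),
)
print([name for name, _ in inorder(root)])  # ['Donald Doe', 'John Doe', 'Marry Doe']
```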

With that, we can use the tree in our PhoneBook:

class Node:

    def __init__(self, data):

        self.left = None
        self.right = None
        self.data = data

    def insert(self, data):
        # Compare the new value with the parent node
        if self.data:
            if data[0] < self.data[0]:
                if self.left is None:
                    self.left = Node(data)
                else:
                    self.left.insert(data)
            elif data[0] > self.data[0]:
                if self.right is None:
                    self.right = Node(data)
                else:
                    self.right.insert(data)
        else:
            self.data = data

    def inorder_traversal(self, root):
        res = []
        if root:
            res = self.inorder_traversal(root.left)
            res.append(root.data)
            res = res + self.inorder_traversal(root.right)
        return res


class PhoneBook:

    def __init__(self, records=None):
        # Start from an empty root node; the first insert stores its data there
        self.records = Node(None)
        for record in records or []:
            self.records.insert(record)

    def add(self, record):
        self.records.insert(record)

    def all(self):
        # An empty phone book is a root node without data, so there is nothing to return
        if self.records.data is None:
            return []
        return self.records.inorder_traversal(self.records)

Run the tests. They should pass.

Conclusion

Writing tests first helps to define the problem well, which helps with writing a better solution.

You can use tests to help clarify the problem as well as a confusing feature's scope.

Tests then check if your solution solves the problem.

Most clients won't care about how you solve the problem as long as it works; thus, we focused our tests on the interface rather than the implementation. As we made changes to the code, we didn't need to change our tests since the problem the code solved didn't change.
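
One way to see this in practice is to write the interface checks once, as a function that takes a phone book class, and run them against every implementation. The sketch below uses two simplified stand-ins condensed from the list-based versions in this article (`ListPhoneBook`, `InsertSortedPhoneBook`, and `check_phone_book_interface` are hypothetical names). Because the checks only call add and all, they never depend on how records are stored, which is why they keep passing when `records` turns into a tree.

```python
class ListPhoneBook:
    # Sorts on every call to all(), like the first implementation
    def __init__(self, records=None):
        self.records = list(records or [])

    def add(self, record):
        self.records.append(record)

    def all(self):
        return sorted(self.records)


class InsertSortedPhoneBook:
    # Keeps records sorted by name and inserts each new record in place
    def __init__(self, records=None):
        self.records = sorted(records or [], key=lambda rec: rec[0])

    def add(self, record):
        index = next(
            (i for i, existing in enumerate(self.records) if record[0] < existing[0]),
            len(self.records),
        )
        self.records.insert(index, record)

    def all(self):
        return self.records


def check_phone_book_interface(phone_book_class):
    # Only add() and all() are used: the checks never look at .records
    phone_book = phone_book_class(
        records=[
            ('John Doe', '03 234 567 890'),
            ('Marry Doe', '01 234 567 890'),
        ]
    )
    phone_book.add(('Donald Doe', '02 234 567 890'))

    names = [name for name, _ in phone_book.all()]
    assert names == ['Donald Doe', 'John Doe', 'Marry Doe']


for implementation in (ListPhoneBook, InsertSortedPhoneBook):
    check_phone_book_interface(implementation)
    print(f'{implementation.__name__}: OK')
```

With pytest, the same idea maps onto `@pytest.mark.parametrize('phone_book_class', [ListPhoneBook, InsertSortedPhoneBook])` on a test function that takes `phone_book_class` as an argument, so each implementation is reported as its own test.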

As the complexity of the implementation increases, it may become necessary to add unit tests at that level as well. I recommend focusing your time and attention on integration tests at the interface level and only adding unit tests when you find that your code breaks repeatedly in a specific area.

Tests give us a certain freedom: we can change the implementation without having to worry about breaking the interface. After all, what matters is not how it works but that it works.

TDD can give you the confidence to refactor code for the better, whether that means faster, cleaner, or better structured.

Happy coding!

Jan Giacomelli


Jan is a software engineer who lives in Ljubljana, Slovenia, Europe. He is a Staff Software Engineer at ren.co where he is leading backend engineering efforts. He loves Python, FastAPI, and Test-Driven Development. When he's not writing code, deploying to AWS, or speaking at a conference, he's probably skiing, windsurfing, or playing guitar. Currently, he's working on his new course Complete Python Testing Guide.


Featured Course

Developing a Real-Time Taxi App with Django Channels and Angular

Learn how to create a ride-sharing app with Django Channels, Angular, and Docker. Along the way, you'll learn how to manage client/server communication with Django Channels, control flow and routing with Angular, and build a RESTful API with Django REST Framework.
