WebSockets

Part 1, Chapter 6


Up until now, we have dealt with users in a generic way: Users can authenticate and they can retrieve trips. The following section separates users into distinct roles, and this is where things get interesting. Fundamentally, users can participate in trips in one of two ways: they either drive the cars or they ride in them. A rider initiates the trip with a request, which is broadcast to all available drivers. A driver starts a trip by accepting the request. At this point, the driver heads to the pick-up address. The rider is instantly alerted that a driver has started the trip, and the other drivers are notified that the trip is no longer up for grabs.

Instantaneous communication between the driver and the rider is vital here, and we can achieve it using WebSockets via Django Channels.

Django Channels Setup

Test

Create a new trips/tests/test_websocket.py file and add the following test:

# tests/test_websocket.py

from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.test import Client
from channels.db import database_sync_to_async
from channels.layers import get_channel_layer
from channels.testing import WebsocketCommunicator
import pytest

from taxi.routing import application
from trips.models import Trip


TEST_CHANNEL_LAYERS = {
    'default': {
        'BACKEND': 'channels.layers.InMemoryChannelLayer',
    },
}


@database_sync_to_async
def create_user(
    *,
    username='[email protected]',
    password='pAssw0rd!',
    group='rider'
):
    # Create user.
    user = get_user_model().objects.create_user(
        username=username,
        password=password
    )

    # Create user group.
    user_group, _ = Group.objects.get_or_create(name=group)
    user.groups.add(user_group)
    user.save()
    return user


@pytest.mark.asyncio
@pytest.mark.django_db(transaction=True)
class TestWebsockets:

    async def test_authorized_user_can_connect(self, settings):
        # Use in-memory channel layers for testing.
        settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

        # Force authentication to get session ID.
        client = Client()
        user = await create_user()
        client.force_login(user=user)

        # Pass session ID in headers to authenticate.
        communicator = WebsocketCommunicator(
            application=application,
            path='/taxi/',
            headers=[(
                b'cookie',
                f'sessionid={client.cookies["sessionid"].value}'.encode('ascii')
            )]
        )
        connected, _ = await communicator.connect()
        assert connected is True
        await communicator.disconnect()

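There's a lot going on here. The first thing you'll notice is that we're using pytest instead of the built-in Django testing tools. We're also writing the tests as coroutines (async def functions), using the async/await syntax introduced in Python 3.5 on top of the asyncio module added in Python 3.4. The testing helpers that ship with Django Channels 2.x are asyncio-based, which is why we run them with pytest plus the pytest-asyncio and pytest-django plugins.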

If you're not familiar with asynchronous programming, we'd strongly encourage you to learn the basics, starting with the official asyncio Python documentation and the excellent Python & Async Simplified guide (by the maintainer of Django Channels, Andrew Godwin).

Remember how we created HTTP test classes by extending APITestCase? Grouping multiple tests with pytest only requires you to write a basic class. We named ours TestWebsockets. We also decorated the class with two marks, which set metadata on each of the test methods it contains. The @pytest.mark.asyncio mark tells pytest to run the tests as asyncio coroutines. The @pytest.mark.django_db mark allows tests to access the Django database. Specifying transaction=True ensures that the database will be flushed between tests.

Let's look at the create_user() function next. The Django ORM is synchronous, not asynchronous, so you can't simply call it from a coroutine: the call has to run in a separate thread, and its database connections have to be closed properly afterwards. All functions that access the Django ORM should therefore be decorated with @database_sync_to_async.
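The idea behind @database_sync_to_async can be sketched with the standard library alone: wrap the blocking function in a coroutine that hands the call to a worker thread and awaits the result, so the event loop stays free while the database works. This is a simplified analogue, not Channels' code; the real decorator builds on asgiref's sync_to_async and also closes stale database connections before and after each call. The names run_in_thread, fake_query, and _DB_EXECUTOR are hypothetical.

```python
import asyncio
import functools
import threading
from concurrent.futures import ThreadPoolExecutor

# Hypothetical single worker thread standing in for the thread that
# runs ORM calls (an assumption made for this sketch).
_DB_EXECUTOR = ThreadPoolExecutor(max_workers=1)


def run_in_thread(func):
    """Wrap a blocking function so it can be awaited from a coroutine."""
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        # Hand the blocking call to the worker thread; the event loop
        # keeps serving other connections while we wait for the result.
        return await loop.run_in_executor(
            _DB_EXECUTOR, functools.partial(func, *args, **kwargs)
        )
    return wrapper


@run_in_thread
def fake_query(user_id):
    # Stands in for a blocking ORM call and reports which thread ran it.
    return {'id': user_id, 'thread': threading.get_ident()}


async def main():
    row = await fake_query(42)
    return row, threading.get_ident()
```

Calling asyncio.run(main()) returns the fake row together with the event loop's thread ID; the two thread IDs differ because the blocking call ran in the worker thread.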

Finally, we come to the actual test. First, pay attention to the fact that we included a TEST_CHANNEL_LAYERS constant at the top of the file after the imports. We used that constant in the first line of our test along with the settings fixture provided by pytest-django. This line temporarily overrides the application's settings for the duration of the test so that it uses the InMemoryChannelLayer instead of the configured RedisChannelLayer. Doing this allows us to focus our tests on the behavior we are programming rather than the implementation with Redis. Rest assured that when we run our server in a non-testing environment, Redis will be used.

We went through a lot of trouble setting up authentication in earlier chapters of this course. Requests over WebSockets use authentication too. In the browser, the request that opens a WebSocket connection carries the site's cookies (including our sessionid) to the server. Remember, the sessionid cookie is saved in our browser after a successful login.


We have to handle this behavior explicitly in our test by creating an instance of Client() and forcing a login with the authentication backend. Then we can extract the sessionid cookie from the Client and add it to our cookie header in our WebSockets request. We send this request using WebsocketCommunicator, which is essentially the Channels counterpart to Django's Client.
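The header the test builds follows the ASGI convention: a list of (name, value) pairs in which both items are byte strings. The sketch below shows that shape and reads the session ID back out with the standard library's http.cookies.SimpleCookie. The helper names session_cookie_header() and read_session_id() are hypothetical.

```python
from http.cookies import SimpleCookie


def session_cookie_header(session_id):
    """Build ASGI-style headers carrying a Django session cookie."""
    # ASGI headers are (name, value) pairs of byte strings, which is why
    # the test encodes the cookie string with .encode('ascii').
    return [(b'cookie', f'sessionid={session_id}'.encode('ascii'))]


def read_session_id(headers):
    """Pull the sessionid back out, as cookie-parsing middleware would."""
    for name, value in headers:
        if name == b'cookie':
            cookie = SimpleCookie()
            cookie.load(value.decode('ascii'))
            if 'sessionid' in cookie:
                return cookie['sessionid'].value
    return None
```

In the real stack you don't write the parsing half yourself: AuthMiddlewareStack reads the cookie, loads the session, and puts the user in scope['user'].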

Add a pytest.ini file to the outermost directory. Your directory structure should now look like this:

.
├── pytest.ini
└── server
    └── taxi
        ├── db.sqlite3
        ├── manage.py
        ├── taxi
        │   ├── __init__.py
        │   ├── asgi.py
        │   ├── routing.py
        │   ├── settings.py
        │   ├── urls.py
        │   └── wsgi.py
        └── trips
            ├── __init__.py
            ├── admin.py
            ├── apps.py
            ├── migrations
            │   ├── 0001_initial.py
            │   ├── 0002_trip.py
            │   └── __init__.py
            ├── models.py
            ├── serializers.py
            ├── tests
            │   ├── __init__.py
            │   ├── test_http.py
            │   └── test_websocket.py
            ├── urls.py
            └── views.py

Then add the following three lines to pytest.ini.

[pytest]
DJANGO_SETTINGS_MODULE = taxi.settings
python_files = test_websocket.py

From the "server/taxi" directory, run pytest and watch the test fail.

(env)$ pytest

You should see:

ValueError: No application configured for scope type 'websocket'

Consumer

We need to add a consumer, Channels' version of a Django view. Create a new trips/consumers.py file with the following code:

# trips/consumers.py

from channels.generic.websocket import AsyncJsonWebsocketConsumer


class TaxiConsumer(AsyncJsonWebsocketConsumer):

    async def connect(self):
        user = self.scope['user']
        if user.is_anonymous:
            await self.close()
        else:
            await self.accept()

We access the user from the scope like we would from a traditional Django request. Funny how closely all of these features match up. Our connect() method accepts the connection if the user is authenticated and rejects it otherwise.

Now we need to update our routing.py file to get our tests to pass.

# taxi/routing.py

from django.urls import path  # new
from channels.auth import AuthMiddlewareStack  # new
from channels.routing import ProtocolTypeRouter, URLRouter  # changed

from trips.consumers import TaxiConsumer

# changed
application = ProtocolTypeRouter({
    'websocket': AuthMiddlewareStack(
        URLRouter([
            path('taxi/', TaxiConsumer),
        ])
    )
})

This setup declares that all WebSockets requests should be passed through an AuthMiddlewareStack, which processes cookies and handles session authentication. We also define routes with URLRouter in a way that is reminiscent of the Django urlconf.
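To see where the earlier ValueError came from, here is a toy model of the two routers: one picks an application by the scope's type, the other by its path. It is a simplified sketch, not Channels' implementation; real ASGI applications also receive receive and send callables, and real routes use Django's path matching. The functions protocol_router(), url_router(), and taxi_handler() are hypothetical.

```python
import asyncio


def protocol_router(apps_by_type):
    """Toy ProtocolTypeRouter: pick an application by scope['type']."""
    async def app(scope):
        if scope['type'] not in apps_by_type:
            raise ValueError(
                f"No application configured for scope type {scope['type']!r}"
            )
        return await apps_by_type[scope['type']](scope)
    return app


def url_router(handlers_by_path):
    """Toy URLRouter: exact-match the path, ignoring the leading slash."""
    async def app(scope):
        path = scope['path'].lstrip('/')
        if path not in handlers_by_path:
            raise ValueError(f'No route found for path {path!r}')
        return await handlers_by_path[path](scope)
    return app


async def taxi_handler(scope):
    # Stands in for TaxiConsumer.
    return f"handled {scope['path']}"


application = protocol_router({
    'websocket': url_router({'taxi/': taxi_handler}),
})

if __name__ == '__main__':
    print(asyncio.run(application({'type': 'websocket', 'path': '/taxi/'})))
```

With an empty mapping, protocol_router({}) raises the same kind of error for a websocket scope, which mirrors the failure we saw before routing.py had a 'websocket' entry.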

Run the tests again to see them pass.

(env)$ pytest

Refactoring

We're going to have to create a user, authenticate it, and pass it in the request as part of every test from this point forward. Let's refactor our code to capture that behavior as part of each test's setup.

Add the following code to the bottom of our trips/tests/test_websocket.py file:

# tests/test_websocket.py

async def auth_connect(user):
    # Force authentication to get session ID.
    client = Client()
    client.force_login(user=user)

    # Pass session ID in headers to authenticate.
    communicator = WebsocketCommunicator(
        application=application,
        path='/taxi/',
        headers=[(
            b'cookie',
            f'sessionid={client.cookies["sessionid"].value}'.encode('ascii')
        )]
    )
    connected, _ = await communicator.connect()
    assert connected is True
    return communicator

Then update the existing test to use it.

async def test_authorized_user_can_connect(self, settings):
    # Use in-memory channel layers for testing.
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

    user = await create_user(
        username='[email protected]',
        group='rider'
    )
    communicator = await auth_connect(user)
    await communicator.disconnect()

Run the pytest tests again. They still pass. Awesome.

Create Trips

Test

Next, we're going to be handling the functionality that allows riders to create trips and drivers to update them. Add the following new test to TestWebsockets in trips/tests/test_websocket.py:

# tests/test_websocket.py

async def test_rider_can_create_trips(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

    user = await create_user(
        username='[email protected]',
        group='rider'
    )
    communicator = await auth_connect(user)

    # Send JSON message to server.
    await communicator.send_json_to({
        'type': 'create.trip',
        'data': {
            'pick_up_address': 'A',
            'drop_off_address': 'B',
            'rider': user.id,
        }
    })

    # Receive JSON message from server.
    response = await communicator.receive_json_from()
    data = response.get('data')

    # Confirm data.
    assert data['id'] is not None
    assert 'A' == data['pick_up_address']
    assert 'B' == data['drop_off_address']
    assert Trip.REQUESTED == data['status']
    assert data['driver'] is None
    assert user.username == data['rider'].get('username')

    await communicator.disconnect()

After this test establishes an authenticated WebSockets connection, it sends a JSON-encoded message to the server, which will then create a new Trip and will return it to the client in a response. All messages should include a type. Also, remember to disconnect from the server at the end of every test.

Run pytest and confirm that the new test fails.

Model

Add the following fields to the Trip model in trips/models.py:

# trips/models.py

class Trip(models.Model):
    REQUESTED = 'REQUESTED'
    STARTED = 'STARTED'
    IN_PROGRESS = 'IN_PROGRESS'
    COMPLETED = 'COMPLETED'
    STATUSES = (
        (REQUESTED, REQUESTED),
        (STARTED, STARTED),
        (IN_PROGRESS, IN_PROGRESS),
        (COMPLETED, COMPLETED),
    )

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    created = models.DateTimeField(auto_now_add=True)
    updated = models.DateTimeField(auto_now=True)
    pick_up_address = models.CharField(max_length=255)
    drop_off_address = models.CharField(max_length=255)
    status = models.CharField(max_length=20, choices=STATUSES, default=REQUESTED)
    driver = models.ForeignKey( # new
        settings.AUTH_USER_MODEL,
        null=True,
        blank=True,
        on_delete=models.DO_NOTHING,
        related_name='trips_as_driver'
    )
    rider = models.ForeignKey( # new
        settings.AUTH_USER_MODEL,
        null=True,
        blank=True,
        on_delete=models.DO_NOTHING,
        related_name='trips_as_rider'
    )

    def __str__(self):
        return f'{self.id}'

    def get_absolute_url(self):
        return reverse('trip:trip_detail', kwargs={'trip_id': self.id})

Add the import:

from django.conf import settings

We expanded our existing Trip model to link a driver and a rider to each trip. Remember: Drivers and riders are just normal users that belong to different user groups. Later on, we'll see how the same app can serve both types of users and give each a unique experience.

Make and run migrations to update our Trip model database table:

(env)$ python manage.py makemigrations trips --name trip_driver_rider
(env)$ python manage.py migrate

Let's update our admin page to reflect the changes we made in our model.

# trips/admin.py

from django.contrib import admin
from django.contrib.auth.admin import UserAdmin as DefaultUserAdmin

from .models import Trip, User


@admin.register(User)
class UserAdmin(DefaultUserAdmin):
    pass


@admin.register(Trip)
class TripAdmin(admin.ModelAdmin):
    fields = ( # changed
        'id', 'pick_up_address', 'drop_off_address', 'status',
        'driver', 'rider',
        'created', 'updated',
    )
    list_display = ( # changed
        'id', 'pick_up_address', 'drop_off_address', 'status',
        'driver', 'rider',
        'created', 'updated',
    )
    list_filter = (
        'status',
    )
    readonly_fields = (
        'id', 'created', 'updated',
    )

Serializer

By default, our TripSerializer processes related models as primary keys. That is the exact behavior that we want when we use a serializer to create a database record. On the other hand, when we get the serialized Trip data back from the server, we want more information about the rider and the driver than just their database IDs.

In trips/serializers.py, add a new ReadOnlyTripSerializer below the existing TripSerializer. The difference is that ReadOnlyTripSerializer serializes the full User object (using our existing UserSerializer) instead of just its primary key.

class ReadOnlyTripSerializer(serializers.ModelSerializer):
    driver = UserSerializer(read_only=True)
    rider = UserSerializer(read_only=True)

    class Meta:
        model = Trip
        fields = '__all__'
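The difference between the two serializers can be illustrated without Django. In the sketch below, plain dataclasses stand in for the models, flat mode mirrors TripSerializer's default primary-key representation of related users, and nested mode mirrors ReadOnlyTripSerializer. The User and Trip dataclasses and serialize_trip() are hypothetical simplifications, not the real models or Django REST Framework code.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class User:
    id: int
    username: str


@dataclass
class Trip:
    id: str
    pick_up_address: str
    drop_off_address: str
    rider: Optional[User] = None
    driver: Optional[User] = None


def serialize_trip(trip, nested=False):
    """Flat: related users as primary keys (like TripSerializer).
    Nested: related users as full objects (like ReadOnlyTripSerializer)."""
    def related(user):
        if user is None:
            return None
        return asdict(user) if nested else user.id

    return {
        'id': trip.id,
        'pick_up_address': trip.pick_up_address,
        'drop_off_address': trip.drop_off_address,
        'rider': related(trip.rider),
        'driver': related(trip.driver),
    }
```

For a rider with id 7, the flat form gives 'rider': 7, while the nested form gives a dictionary with the rider's id and username, which is why the tests can check data['rider'].get('username').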

Consumer

Update the trips/consumers.py file:

# trips/consumers.py

from channels.db import database_sync_to_async # new
from channels.generic.websocket import AsyncJsonWebsocketConsumer

from trips.serializers import ReadOnlyTripSerializer, TripSerializer # new


class TaxiConsumer(AsyncJsonWebsocketConsumer):

    async def connect(self):
        user = self.scope['user']
        if user.is_anonymous:
            await self.close()
        else:
            await self.accept()

    # new
    async def receive_json(self, content, **kwargs):
        message_type = content.get('type')
        if message_type == 'create.trip':
            await self.create_trip(content)

    # new
    async def create_trip(self, event):
        trip = await self._create_trip(event.get('data'))
        trip_data = ReadOnlyTripSerializer(trip).data
        await self.send_json({
            'type': 'create.trip',
            'data': trip_data
        })

    # new
    @database_sync_to_async
    def _create_trip(self, content):
        serializer = TripSerializer(data=content)
        serializer.is_valid(raise_exception=True)
        trip = serializer.create(serializer.validated_data)
        return trip

All incoming messages are received by the receive_json() method in the consumer. This is where you should delegate the business logic for each message type. Our create_trip() method creates a new trip and passes its details back to the client. Note that the actual database write happens in the _create_trip() helper, which is decorated with @database_sync_to_async.
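As more message types arrive, the if/elif chain in receive_json() can alternatively be written as a dispatch table that maps each type to a handler method. The sketch below shows that alternative on a hypothetical, Channels-free DispatchingConsumer class. Replying with an error message for unknown types is our own assumption, since the consumer above simply ignores them. The table also includes update.trip, a type this chapter adds later, to show how it grows.

```python
import asyncio


class DispatchingConsumer:
    """Hypothetical stand-in for TaxiConsumer's client-message routing."""

    def __init__(self):
        self.sent = []  # records what send_json() would deliver to the client

    async def send_json(self, content):
        self.sent.append(content)

    async def receive_json(self, content, **kwargs):
        handlers = {
            'create.trip': self.create_trip,
            'update.trip': self.update_trip,
        }
        handler = handlers.get(content.get('type'))
        if handler is None:
            # Assumption: report unknown types instead of ignoring them.
            await self.send_json({'type': 'error', 'data': 'Unknown message type.'})
            return
        await handler(content)

    async def create_trip(self, event):
        await self.send_json({'type': 'create.trip', 'data': event.get('data')})

    async def update_trip(self, event):
        await self.send_json({'type': 'update.trip', 'data': event.get('data')})


async def main():
    consumer = DispatchingConsumer()
    await consumer.receive_json({'type': 'create.trip', 'data': {'pick_up_address': 'A'}})
    await consumer.receive_json({'type': 'cancel.everything'})
    return consumer.sent
```

Adding a message type then means adding one dictionary entry and one handler method.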

Ensure the test now passes.

Refactoring

Let's do another round of refactoring to avoid duplicating code. Add the following helper to the bottom of the trips/tests/test_websocket.py file:

# tests/test_websocket.py

async def connect_and_create_trip(
    *,
    user,
    pick_up_address='A',
    drop_off_address='B'
):
    communicator = await auth_connect(user)
    await communicator.send_json_to({
        'type': 'create.trip',
        'data': {
            'pick_up_address': pick_up_address,
            'drop_off_address': drop_off_address,
            'rider': user.id,
        }
    })
    return communicator

Update the test we just created too.

async def test_rider_can_create_trips(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

    user = await create_user(
        username='[email protected]',
        group='rider'
    )
    communicator = await connect_and_create_trip(user=user)

    # Receive JSON message from server.
    response = await communicator.receive_json_from()
    data = response.get('data')

    # Confirm data.
    assert data['id'] is not None
    assert 'A' == data['pick_up_address']
    assert 'B' == data['drop_off_address']
    assert Trip.REQUESTED == data['status']
    assert data['driver'] is None
    assert user.username == data['rider'].get('username')

    await communicator.disconnect()

Ensure the test still passes as expected.

Another Test

When a rider creates a trip, he should be automatically registered to receive updates about that trip, so that whenever a driver updates the trip, the rider will receive a notification.

Add a new test to trips/tests/test_websocket.py:

# tests/test_websocket.py

async def test_rider_is_added_to_trip_group_on_create(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

    user = await create_user(
        username='[email protected]',
        group='rider'
    )

    # Connect and send JSON message to server.
    communicator = await connect_and_create_trip(user=user)

    # Receive JSON message from server.
    # Rider should be added to new trip's group.
    response = await communicator.receive_json_from()
    data = response.get('data')

    trip_id = data['id']
    message = {
        'type': 'echo.message',
        'data': 'This is a test message.'
    }

    # Send JSON message to new trip's group.
    channel_layer = get_channel_layer()
    await channel_layer.group_send(trip_id, message=message)

    # Receive JSON message from server.
    response = await communicator.receive_json_from()

    # Confirm data.
    assert message == response

    await communicator.disconnect()

This test should prove that once a rider has created a trip, he gets added to a group to receive updates about it. Our test accesses that group and then sends a message to it via the group_send() method.
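Conceptually, a channel layer keeps a mailbox per channel and a set of channel names per group, and group_send() copies a message into every member's mailbox. The toy class below models just that with asyncio.Queue. It is a simplified sketch for intuition, not InMemoryChannelLayer's actual implementation (which also deals with details such as message expiry and capacity limits); ToyChannelLayer is hypothetical.

```python
import asyncio
from collections import defaultdict


class ToyChannelLayer:
    """Single-process model of channels, groups, and group_send()."""

    def __init__(self):
        self.mailboxes = defaultdict(asyncio.Queue)  # channel name -> queue
        self.groups = defaultdict(set)               # group name -> channel names

    async def group_add(self, group, channel):
        self.groups[group].add(channel)

    async def group_discard(self, group, channel):
        self.groups[group].discard(channel)

    async def group_send(self, group, message):
        # Deliver a copy of the message to every channel in the group.
        for channel in self.groups[group]:
            await self.mailboxes[channel].put(dict(message))

    async def receive(self, channel):
        # Waits until a message arrives for this channel.
        return await self.mailboxes[channel].get()


async def main():
    layer = ToyChannelLayer()
    await layer.group_add('trip-1', 'rider-channel')
    await layer.group_add('trip-1', 'driver-channel')
    await layer.group_send('trip-1', {'type': 'echo.message', 'data': 'hi'})
    rider_msg = await layer.receive('rider-channel')
    driver_msg = await layer.receive('driver-channel')
    await layer.group_discard('trip-1', 'rider-channel')
    return rider_msg, driver_msg, layer.groups['trip-1']
```

Both members of the trip-1 group receive the same message, and after group_discard() only the driver's channel remains in the group.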

Again, ensure the new test fails.

Consumer

# trips/consumers.py

import asyncio # new

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer

from trips.serializers import ReadOnlyTripSerializer, TripSerializer


class TaxiConsumer(AsyncJsonWebsocketConsumer):

    # new
    def __init__(self, scope):
        super().__init__(scope)

        # Keep track of the user's trips.
        self.trips = set()

    async def connect(self): ...

    async def receive_json(self, content, **kwargs): ...

    # new
    async def echo_message(self, event):
        await self.send_json(event)

    # changed
    async def create_trip(self, event):
        trip = await self._create_trip(event.get('data'))
        trip_id = f'{trip.id}'
        trip_data = ReadOnlyTripSerializer(trip).data

        # Add trip to set.
        self.trips.add(trip_id)

        # Add this channel to the new trip's group.
        await self.channel_layer.group_add(
            group=trip_id,
            channel=self.channel_name
        )

        await self.send_json({
            'type': 'create.trip',
            'data': trip_data
        })

    # new
    async def disconnect(self, code):
        # Remove this channel from every trip's group.
        channel_groups = [
            self.channel_layer.group_discard(
                group=trip,
                channel=self.channel_name
            )
            for trip in self.trips
        ]
        await asyncio.gather(*channel_groups)

        # Remove all references to trips.
        self.trips.clear()

        await super().disconnect(code)

    @database_sync_to_async
    def _create_trip(self, content): ...

Remember how we insisted on including a type in every message? Channels handles messages sent to groups (over the channel layer) differently from messages the client sends directly to the server, which arrive in receive_json(). For a channel-layer message, Channels replaces every "." in the message type with "_" and calls the consumer method with the matching name. In this case, a message with the type echo.message is handled by the echo_message() method.
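The lookup can be approximated in a few lines: derive the method name from the type, then fetch it with getattr(). The sketch includes a guard against types that would map onto underscore-prefixed (private) methods; Channels applies a similar check, but treat this as a simplified reconstruction rather than its source. handler_name_for() and EchoConsumer are hypothetical names used for illustration.

```python
import asyncio


def handler_name_for(message):
    """Turn a message type like 'echo.message' into a method name."""
    if 'type' not in message:
        raise ValueError("Incoming message has no 'type' attribute")
    handler_name = message['type'].replace('.', '_')
    if handler_name.startswith('_'):
        # Refuse names that would reach private helpers such as _create_trip().
        raise ValueError('Malformed type in message (leading underscore)')
    return handler_name


class EchoConsumer:
    """Hypothetical consumer with one channel-layer handler."""

    def __init__(self):
        self.sent = []

    async def echo_message(self, event):
        self.sent.append(event)  # TaxiConsumer calls send_json(event) here

    async def dispatch(self, message):
        handler = getattr(self, handler_name_for(message), None)
        if handler is None:
            raise ValueError(f"No handler for message type {message['type']!r}")
        await handler(message)
```

A type such as '.create.trip' would map to '_create_trip', which is exactly the kind of private helper the guard keeps out of reach.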

Some other things to note:

  • We initialize a trips set and keep track of the rider's trips for the lifetime of the WebSocket connection. (We could also store them on the user's session if we wanted to.)
  • When we create a trip, we add its ID to the tracked set. We also add the user's channel to a group named after the new trip's ID.
  • We explicitly remove the user's channel from each trip group when he disconnects. The asyncio.gather() function schedules the group_discard() coroutines to run concurrently; awaiting its result waits until all of them have finished.
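The behavior of asyncio.gather() is easy to check in isolation. In the sketch below, a hypothetical discard_from_group() coroutine sleeps to imitate a channel-layer call; gathering three of them takes roughly as long as one, because they run concurrently, and the results come back in the order the coroutines were passed in. Note that gather() returns an awaitable: only awaiting it makes the caller wait until every coroutine has finished and lets any exception propagate.

```python
import asyncio
import time


async def discard_from_group(group, delay=0.1):
    # Hypothetical stand-in for channel_layer.group_discard(); sleeps to mimic I/O.
    await asyncio.sleep(delay)
    return group


async def leave_all_groups(groups):
    # Awaiting gather() runs the coroutines concurrently and waits for all
    # of them; results come back in the order the coroutines were passed in.
    return await asyncio.gather(*(discard_from_group(g) for g in groups))


async def main():
    start = time.perf_counter()
    left = await leave_all_groups(['trip-1', 'trip-2', 'trip-3'])
    return left, time.perf_counter() - start
```

Awaited one after another, the same three calls would take at least 0.3 seconds.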

The test should pass.

Accessing Persistent Trip Data

Test

We're successfully tracking a rider's trips as long as his session is alive. But what happens when that rider closes the app and then opens it again? We need to re-establish the connections to his groups.

Let's create a new test. Start by adding a helper function that creates a trip in the database:


# tests/test_websocket.py

@database_sync_to_async
def create_trip(**kwargs):
    return Trip.objects.create(**kwargs)

Add the test.

# tests/test_websocket.py

async def test_rider_is_added_to_trip_groups_on_connect(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

    user = await create_user(
        username='[email protected]',
        group='rider'
    )

    # Create a trip and link it to the rider.
    trip = await create_trip(
        pick_up_address='A',
        drop_off_address='B',
        rider=user
    )

    # Connect to server.
    # Trips for rider should be retrieved.
    # Rider should be added to trips' groups.
    communicator = await auth_connect(user)

    message = {
        'type': 'echo.message',
        'data': 'This is a test message.'
    }

    channel_layer = get_channel_layer()

    # Test sending JSON message to trip group.
    await channel_layer.group_send(f'{trip.id}', message=message)
    response = await communicator.receive_json_from()
    assert message == response

    await communicator.disconnect()

Consumer

Update the consumer again.

# trips/consumers.py

import asyncio

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer

from trips.models import Trip # new
from trips.serializers import ReadOnlyTripSerializer, TripSerializer


class TaxiConsumer(AsyncJsonWebsocketConsumer):

    def __init__(self, scope): ...

    # changed
    async def connect(self):
        user = self.scope['user']
        if user.is_anonymous:
            await self.close()
        else:
            # Get the user's active trips and add this channel to each trip's group.
            trip_ids = await self._get_trips(user)
            self.trips = {str(trip_id) for trip_id in trip_ids}
            channel_groups = [
                self.channel_layer.group_add(trip, self.channel_name)
                for trip in self.trips
            ]
            # Await gather() so every group_add() finishes before we accept.
            await asyncio.gather(*channel_groups)
            await self.accept()

    async def receive_json(self, content, **kwargs): ...

    async def echo_message(self, event): ...

    async def create_trip(self, event): ...

    async def disconnect(self, code): ...

    @database_sync_to_async
    def _create_trip(self, content): ...

    # new
    @database_sync_to_async
    def _get_trips(self, user):
        if not user.is_authenticated:
            raise Exception('User is not authenticated.')
        user_groups = user.groups.values_list('name', flat=True)
        if 'driver' in user_groups:
            trips = user.trips_as_driver.exclude(status=Trip.COMPLETED)
        else:
            trips = user.trips_as_rider.exclude(status=Trip.COMPLETED)
        # Evaluate the query here, inside the sync helper, rather than
        # returning a lazy QuerySet to the async caller.
        return list(trips.values_list('id', flat=True))

We expanded the connect() method to retrieve the user's active (not yet completed) trips and add the user to each trip's group. Because _get_trips() queries the database, it needs the @database_sync_to_async decorator.
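One subtlety: @database_sync_to_async only moves the work done inside the decorated function off the event loop. Django QuerySets, including the one values_list() returns, are lazy, so if a decorated helper returns an unevaluated QuerySet, the query actually runs later, when the async caller iterates over it (Django 3.0 and later refuse this with a SynchronousOnlyOperation error). Converting the result with list() inside the decorated function avoids the problem. The stdlib sketch below mimics the effect with a generator; lazy_trip_ids() and fetch() are hypothetical.

```python
import asyncio
import threading


def lazy_trip_ids(seen_threads):
    """Stand-in for a lazy QuerySet: no work happens until iteration."""
    for trip_id in ('t1', 't2'):
        seen_threads.add(threading.get_ident())  # record where the work ran
        yield trip_id


async def fetch(evaluate_in_worker):
    seen_threads = set()

    def query():
        rows = lazy_trip_ids(seen_threads)
        # list() forces evaluation here, inside the worker thread.
        return list(rows) if evaluate_in_worker else rows

    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, query)
    trip_ids = set(result)  # still-lazy work runs here, on the event loop's thread
    return trip_ids, seen_threads, threading.get_ident()
```

With evaluate_in_worker=True the work is recorded only in the worker thread; with False it is recorded only in the event loop's thread, even though query() itself ran in the worker.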

Modify the create_trip() method slightly.

# trips/consumers.py

async def create_trip(self, event):
    trip = await self._create_trip(event.get('data'))
    trip_id = f'{trip.id}'
    trip_data = ReadOnlyTripSerializer(trip).data

    # Handle add only if trip is not being tracked.
    if trip_id not in self.trips:
        self.trips.add(trip_id)
        await self.channel_layer.group_add(
            group=trip_id,
            channel=self.channel_name
        )

    await self.send_json({
        'type': 'create.trip',
        'data': trip_data
    })

We don't want to add the rider to the same trip group twice. Run the tests again; all four should pass:

(env)$ pytest
======================================== test session starts ========================================
platform darwin -- Python 3.7.5, pytest-5.2.2, py-1.8.0, pluggy-0.13.0
Django settings: taxi.settings (from ini file)
rootdir: /Users/michael.herman/repos/testdriven/taxi-app, inifile: pytest.ini
plugins: django-3.6.0, asyncio-0.10.0
collected 4 items

trips/tests/test_websocket.py ....

Update Trips

Test

Let's handle the functionality to update existing trips. Start with a test. We can already anticipate needing to reuse the updating behavior, so we can avoid a later refactoring by adding a proper helper function now.

# trips/tests/test_websocket.py

async def connect_and_update_trip(*, user, trip, status):
    communicator = await auth_connect(user)
    await communicator.send_json_to({
        'type': 'update.trip',
        'data': {
            'id': f'{trip.id}',
            'pick_up_address': trip.pick_up_address,
            'drop_off_address': trip.drop_off_address,
            'status': status,
            'driver': user.id,
        }
    })
    return communicator

Here's the first test.

async def test_driver_can_update_trips(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

    trip = await create_trip(
        pick_up_address='A',
        drop_off_address='B'
    )
    user = await create_user(
        username='[email protected]',
        group='driver'
    )

    # Send JSON message to server.
    communicator = await connect_and_update_trip(
        user=user,
        trip=trip,
        status=Trip.IN_PROGRESS
    )

    # Receive JSON message from server.
    response = await communicator.receive_json_from()
    data = response.get('data')

    # Confirm data.
    assert str(trip.id) == data['id']
    assert 'A' == data['pick_up_address']
    assert 'B' == data['drop_off_address']
    assert Trip.IN_PROGRESS == data['status']
    assert user.username == data['driver'].get('username')
    assert data['rider'] is None

    await communicator.disconnect()

In this test, we explicitly update an existing trip's status from REQUESTED to IN_PROGRESS. We send the request, the server updates the trip, and we confirm that the response data matches our expectations.

Consumer

Edit the consumer to handle updates.

# trips/consumers.py

import asyncio

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer

from trips.models import Trip
from trips.serializers import ReadOnlyTripSerializer, TripSerializer


class TaxiConsumer(AsyncJsonWebsocketConsumer):

    def __init__(self, scope): ...

    async def connect(self): ...

    async def receive_json(self, content, **kwargs):
        message_type = content.get('type')
        if message_type == 'create.trip':
            await self.create_trip(content)
        elif message_type == 'update.trip':  # new
            await self.update_trip(content)

    async def echo_message(self, event): ...

    async def create_trip(self, event): ...

    # new
    async def update_trip(self, event):
        trip = await self._update_trip(event.get('data'))
        trip_id = f'{trip.id}'
        trip_data = ReadOnlyTripSerializer(trip).data

        # Handle add only if trip is not being tracked.
        # This happens when a driver accepts a request.
        if trip_id not in self.trips:
            self.trips.add(trip_id)
            await self.channel_layer.group_add(
                group=trip_id,
                channel=self.channel_name
            )

        await self.send_json({
            'type': 'update.trip',
            'data': trip_data
        })

    async def disconnect(self, code): ...

    @database_sync_to_async
    def _create_trip(self, content): ...

    @database_sync_to_async
    def _get_trips(self, user): ...

    # new
    @database_sync_to_async
    def _update_trip(self, content):
        instance = Trip.objects.get(id=content.get('id'))
        serializer = TripSerializer(data=content)
        serializer.is_valid(raise_exception=True)
        trip = serializer.update(instance, serializer.validated_data)
        return trip

We have a new event type—update.trip. We created corresponding update_trip() and _update_trip() methods to process the event. Updating the trip adds the driver to the associated trip group.

Make sure the test passes.

Another Test

Add one more test for the driver.

async def test_driver_is_added_to_trip_group_on_update(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

    trip = await create_trip(
        pick_up_address='A',
        drop_off_address='B'
    )
    user = await create_user(
        username='[email protected]',
        group='driver'
    )

    # Send JSON message to server.
    communicator = await connect_and_update_trip(
        user=user,
        trip=trip,
        status=Trip.IN_PROGRESS
    )

    # Receive JSON message from server.
    response = await communicator.receive_json_from()
    data = response.get('data')

    trip_id = data['id']
    message = {
        'type': 'echo.message',
        'data': 'This is a test message.'
    }

    # Send JSON message to trip's group.
    channel_layer = get_channel_layer()
    await channel_layer.group_send(trip_id, message=message)

    # Receive JSON message from server.
    response = await communicator.receive_json_from()

    # Confirm data.
    assert message == response

    await communicator.disconnect()

The driver should receive a notification of any updates that occur on the trip. This test does not require any changes to the consumer.

All drivers should be alerted whenever a new trip is created. Riders should be alerted when a driver updates the trip that they created. Add tests to capture that behavior.

async def test_driver_is_alerted_on_trip_create(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

    # Subscribe a test channel to the 'drivers' group.
    channel_layer = get_channel_layer()
    await channel_layer.group_add(
        group='drivers',
        channel='test_channel'
    )

    user = await create_user(
        username='[email protected]',
        group='rider'
    )

    # Send JSON message to server.
    communicator = await connect_and_create_trip(user=user)

    # Receive JSON message from server on test channel.
    response = await channel_layer.receive('test_channel')
    data = response.get('data')

    # Confirm data.
    assert data['id'] is not None
    assert user.username == data['rider'].get('username')

    await communicator.disconnect()

async def test_rider_is_alerted_on_trip_update(self, settings):
    settings.CHANNEL_LAYERS = TEST_CHANNEL_LAYERS

    trip = await create_trip(
        pick_up_address='A',
        drop_off_address='B'
    )

    # Subscribe a test channel to the trip's group.
    channel_layer = get_channel_layer()
    await channel_layer.group_add(
        group=f'{trip.id}',
        channel='test_channel'
    )

    user = await create_user(
        username='[email protected]',
        group='driver'
    )

    # Send JSON message to server.
    communicator = await connect_and_update_trip(
        user=user,
        trip=trip,
        status=Trip.IN_PROGRESS
    )

    # Receive JSON message from server on test channel.
    response = await channel_layer.receive('test_channel')
    data = response.get('data')

    # Confirm data.
    assert f'{trip.id}' == data['id']
    assert user.username == data['driver'].get('username')

    await communicator.disconnect()

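Run the tests. The two new tests hang because the consumer does not yet broadcast to the 'drivers' group or to the trip's group, so channel_layer.receive('test_channel') waits indefinitely. Unlike communicator.receive_json_from(), which times out after one second by default, the channel layer's receive() has no timeout of its own.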
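If you would rather have a missing broadcast fail fast than hang, one option is to put a timeout around the receive. The sketch below uses an asyncio.Queue as a stand-in for the test channel and asyncio.wait_for() to bound the wait. receive_with_timeout() is a hypothetical helper, not part of the course code. In the real tests, the same idea might look like await asyncio.wait_for(channel_layer.receive('test_channel'), timeout=1), which raises a timeout error instead of blocking forever. Treat this as a suggestion rather than a required change.

```python
import asyncio


async def receive_with_timeout(channel, timeout):
    # Hypothetical helper: without wait_for, awaiting get() on an empty queue
    # blocks forever, which is how a test "hangs" when nothing is ever sent.
    return await asyncio.wait_for(channel.get(), timeout=timeout)


async def main():
    test_channel = asyncio.Queue()  # stand-in for the 'test_channel' inbox
    outcomes = []

    # Nothing has been sent yet, so this gives up after 0.1 seconds.
    try:
        await receive_with_timeout(test_channel, timeout=0.1)
        outcomes.append('received')
    except asyncio.TimeoutError:
        outcomes.append('timed out')

    # Once a message arrives, the same call returns it immediately.
    await test_channel.put({'type': 'echo.message'})
    message = await receive_with_timeout(test_channel, timeout=0.1)
    outcomes.append(message['type'])
    return outcomes


print(asyncio.run(main()))  # ['timed out', 'echo.message']
```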

Consumer

We need to update several of the consumer's methods to accommodate drivers. First, when the WebSocket connection is established, we check whether the user is a driver. If so, we subscribe them to the 'drivers' group. Next, we alert all drivers whenever a rider broadcasts a trip request.

Once a driver accepts a request, we alert the corresponding rider every time the driver updates the trip's status. For example, the rider should get a message like "your ride is on its way". Lastly, when a driver closes the WebSocket connection (for example, by exiting the app), we remove them from the 'drivers' group.
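The subscription and broadcast rules described above can be summarized as two small pure functions. This sketch rests on our own assumptions. groups_to_join() and broadcast_group() are hypothetical helpers, not part of the consumer. They also assume trip IDs are converted to strings before being used as group names, as the consumer below does with str(trip_id).

```python
def groups_to_join(role, trip_ids):
    """Hypothetical helper: groups a user joins on connect (and leaves on disconnect)."""
    groups = {str(trip_id) for trip_id in trip_ids}  # one group per trip
    if role == 'driver':
        groups.add('drivers')  # drivers also hear every new trip request
    return groups


def broadcast_group(event_type, trip_id):
    """Hypothetical helper: which group is notified for each client event."""
    if event_type == 'create.trip':
        return 'drivers'  # a new request goes to every available driver
    if event_type == 'update.trip':
        return str(trip_id)  # an update goes to everyone following this trip
    raise ValueError(f'Unknown event type: {event_type}')


print(sorted(groups_to_join('rider', ['42'])))        # ['42']
print(sorted(groups_to_join('driver', ['42', '7'])))  # ['42', '7', 'drivers']
print(broadcast_group('create.trip', '42'))           # drivers
print(broadcast_group('update.trip', '42'))           # 42
```

Because the same set of groups is joined on connect and discarded on disconnect, a driver who leaves stops receiving both new requests and updates for trips they had joined.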

# trips/consumers.py

import asyncio

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer

from trips.models import Trip
from trips.serializers import ReadOnlyTripSerializer, TripSerializer


class TaxiConsumer(AsyncJsonWebsocketConsumer):

    def __init__(self, scope): ...

    # changed
    async def connect(self):
        user = self.scope['user']
        if user.is_anonymous:
            await self.close()
        else:
            channel_groups = []

            # Add a driver to the 'drivers' group.
            user_group = await self._get_user_group(user)
            if user_group == 'driver':
                channel_groups.append(self.channel_layer.group_add(
                    group='drivers',
                    channel=self.channel_name
                ))

            # Add the user to one group per trip they belong to.
            trip_ids = await self._get_trips(user)
            self.trips = {str(trip_id) for trip_id in trip_ids}
            for trip in self.trips:
                channel_groups.append(self.channel_layer.group_add(trip, self.channel_name))

            # Wait for every subscription to finish before accepting the connection.
            await asyncio.gather(*channel_groups)

            await self.accept()

    async def receive_json(self, content, **kwargs): ...

    async def echo_message(self, event): ...

    # changed
    async def create_trip(self, event):
        trip = await self._create_trip(event.get('data'))
        trip_id = f'{trip.id}'
        trip_data = ReadOnlyTripSerializer(trip).data

        # Send rider requests to all drivers.
        await self.channel_layer.group_send(group='drivers', message={
            'type': 'echo.message',
            'data': trip_data
        })

        if trip_id not in self.trips:
            self.trips.add(trip_id)
            await self.channel_layer.group_add(
                group=trip_id,
                channel=self.channel_name
            )

        await self.send_json({
            'type': 'create.trip',
            'data': trip_data
        })

    # changed
    async def update_trip(self, event):
        trip = await self._update_trip(event.get('data'))
        trip_id = f'{trip.id}'
        trip_data = ReadOnlyTripSerializer(trip).data

        # Send updates to riders that subscribe to this trip.
        await self.channel_layer.group_send(group=trip_id, message={
            'type': 'echo.message',
            'data': trip_data
        })

        if trip_id not in self.trips:
            self.trips.add(trip_id)
            await self.channel_layer.group_add(
                group=trip_id,
                channel=self.channel_name
            )

        await self.send_json({
            'type': 'update.trip',
            'data': trip_data
        })

    # changed
    async def disconnect(self, code):
        channel_groups = [
            self.channel_layer.group_discard(
                group=trip,
                channel=self.channel_name
            )
            for trip in self.trips
        ]

        # Discard driver from 'drivers' group.
        user_group = await self._get_user_group(self.scope['user'])
        if user_group == 'driver':
            channel_groups.append(self.channel_layer.group_discard(
                group='drivers',
                channel=self.channel_name
            ))

        # Wait for every group_discard() call to finish.
        await asyncio.gather(*channel_groups)
        self.trips.clear()

        await super().disconnect(code)

    @database_sync_to_async
    def _create_trip(self, content): ...

    @database_sync_to_async
    def _get_trips(self, user): ...

    # new
    @database_sync_to_async
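    def _get_user_group(self, user):
        if not user.is_authenticated:
            raise Exception('User is not authenticated.')
        group = user.groups.first()
        # A user without a group has no role; return None instead of raising AttributeError.
        return group.name if group else None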

    @database_sync_to_async
    def _update_trip(self, content): ...
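Both connect() and disconnect() collect group_add() or group_discard() coroutines in a list and pass them to asyncio.gather(), which runs them concurrently. The gather future has to be awaited. If it is not, the coroutines are only scheduled, and the method moves on (for example, to accept()) before any subscription has happened. The sketch below demonstrates this with a hypothetical fake_group_add() coroutine standing in for the real channel layer call.

```python
import asyncio

joined = []


async def fake_group_add(group, channel):
    # Hypothetical stand-in for channel_layer.group_add(); sleeping yields to the loop.
    await asyncio.sleep(0.01)
    joined.append(group)


async def connect(await_gather):
    joined.clear()
    coros = [fake_group_add(g, 'channel_1') for g in ('drivers', 'trip_1', 'trip_2')]
    if await_gather:
        await asyncio.gather(*coros)
        snapshot = sorted(joined)  # what accept() would see
    else:
        pending = asyncio.gather(*coros)  # tasks are scheduled but have not run yet
        snapshot = sorted(joined)  # what accept() would see
        await pending  # demo cleanup only, so no task is left dangling
    return snapshot


print(asyncio.run(connect(await_gather=True)))   # ['drivers', 'trip_1', 'trip_2']
print(asyncio.run(connect(await_gather=False)))  # []
```

Without the await, nothing has been joined by the time the connection would be accepted, so an early broadcast could be missed.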

One last test run.

(env)$ pytest
======================================== test session starts ========================================
platform darwin -- Python 3.7.5, pytest-5.2.2, py-1.8.0, pluggy-0.13.0
Django settings: taxi.settings (from ini file)
rootdir: /Users/michael.herman/repos/testdriven/taxi-app, inifile: pytest.ini
plugins: django-3.6.0, asyncio-0.10.0
collected 8 items

trips/tests/test_websocket.py ........


