Adventures with \Device\Afd - test driven design

I’ve been investigating the ‘sparsely documented’ \Device\Afd interface that lies below the Winsock2 layer. Today I use a test-driven method for building some code to make using this API a little easier.

Using the API

In my previous posts on the \Device\Afd interface I focused on exploring and understanding the API itself. Since it’s ‘sparsely documented’ I used unit tests to explore and record my findings. This has left me with some code that is easy to come back to and pick up, which is lucky since it’s been quite a while since I last had some time to tinker with it.

The next step is to take what I learnt last time and start to build code that could actually be used for socket communication. To that end, I’ve added a “socket” project to the code on GitHub. See the last article for an explanation as to how I’m using GoogleTest and how the code is structured.

Code

Full source can be found here on GitHub. We’re now working with the socket project.

This isn’t production code; error handling is simply “panic and run away”.

This code is licensed with the MIT license.

A TCP socket class

The first thing I’m going to do is write a simple tcp_socket class that wraps up the socket and connection related work. This keeps that work apart from the code that deals with the \Device\Afd API and makes both easier to test and reason about. The end goal is a simple, event-driven socket that will allow me to write a simple TCP client.

class tcp_socket : private afd_events
{
   public:

      explicit tcp_socket(
         afd_handle afd,
         tcp_socket_callbacks &callbacks);

      ~tcp_socket() override;

      void connect(
         const sockaddr &address,
         int address_length);

      int write(
         const BYTE *pData,
         int data_length);

      int read(
         BYTE *pBuffer,
         int buffer_length);

      void close();

      enum class shutdown_how
      {
         receive  = 0x00,
         send     = 0x01,
         both     = 0x02
      };

      void shutdown(
         shutdown_how how);

   private :

      ULONG handle_events(
         ULONG eventsToHandle,
         NTSTATUS status) override;

      const afd_handle afd;

      SOCKET s;

      ULONG events;

      tcp_socket_callbacks &callbacks;

      enum class state
      {
         created,
         pending_connect,
         connected,
         disconnected
      };

      state connection_state;
};

The tcp_socket object interacts with the \Device\Afd API in two ways: through the afd_events interface, which allows the AFD code to call into the socket and get it to deal with events, and through the afd_handle object, which is our connection to the AFD API. The socket object reports events to the code that uses it via a callback interface, tcp_socket_callbacks. This allows the socket to be event-driven and asynchronous. The connect() method, for example, returns immediately and reports success or failure via a method on the tcp_socket_callbacks interface. The callbacks could come from any thread, so we’ll eventually have to take locking and thread-safety into account, but for now we’ll sketch out the design of the code using tests and a single thread.
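
To make the callback idea concrete, here is a minimal sketch of what such an interface could look like. The method names are taken from the GoogleMock class used by the tests later in this post; the empty tcp_socket stand-in, the socket_error alias (in place of DWORD) and the recording_callbacks class are assumptions made so that the sketch compiles on its own, and are not code from the repository.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Empty stand-in so that the sketch compiles on its own; not the real class.
class tcp_socket {};

// Stand-in for the Windows DWORD type.
using socket_error = unsigned long;

// Assumed shape of the callback interface; one method per event.
class tcp_socket_callbacks
{
   public :

      virtual void on_connected(tcp_socket &s) = 0;
      virtual void on_connection_failed(tcp_socket &s, socket_error error) = 0;
      virtual void on_readable(tcp_socket &s) = 0;
      virtual void on_readable_oob(tcp_socket &s) = 0;
      virtual void on_writable(tcp_socket &s) = 0;
      virtual void on_client_close(tcp_socket &s) = 0;
      virtual void on_connection_reset(tcp_socket &s) = 0;
      virtual void on_disconnected(tcp_socket &s) = 0;

      virtual ~tcp_socket_callbacks() = default;
};

// Hypothetical implementation that just records which callbacks fired.
class recording_callbacks : public tcp_socket_callbacks
{
   public :

      std::vector<std::string> log;

      void on_connected(tcp_socket &) override { log.push_back("on_connected"); }
      void on_connection_failed(tcp_socket &, socket_error) override { log.push_back("on_connection_failed"); }
      void on_readable(tcp_socket &) override { log.push_back("on_readable"); }
      void on_readable_oob(tcp_socket &) override { log.push_back("on_readable_oob"); }
      void on_writable(tcp_socket &) override { log.push_back("on_writable"); }
      void on_client_close(tcp_socket &) override { log.push_back("on_client_close"); }
      void on_connection_reset(tcp_socket &) override { log.push_back("on_connection_reset"); }
      void on_disconnected(tcp_socket &) override { log.push_back("on_disconnected"); }
};

// Drives the interface by hand, in the order a short-lived connection might.
std::vector<std::string> example_callback_sequence()
{
   tcp_socket socket;
   recording_callbacks callbacks;

   tcp_socket_callbacks &sink = callbacks;

   sink.on_connected(socket);
   sink.on_readable(socket);
   sink.on_client_close(socket);

   assert(callbacks.log.size() == 3);

   return callbacks.log;
}
```

Passing the socket into every callback means that one callbacks object could, in principle, serve several sockets, which is one reason to prefer this shape over callbacks with no arguments.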

The first test

As is common with Test Driven Development, the first test is the most important and often exposes most of the code-under-test’s dependencies. Here’s our first test…

TEST(AFDSocket, TestConstruct)
{
   const auto handles = CreateAfdAndIOCP();

   afd_system afd(handles.afd);

   afd_handle handle(afd, 0);

   mock_tcp_socket_callbacks callbacks;

   tcp_socket socket(handle, callbacks);
}

This shows that I’ve already thought about a quick and dirty way to separate the AFD code from the socket code. I use an afd_system object which wraps the API. This afd_system is then accessed via an afd_handle, which wraps a handle to the internals of the afd_system and, in this case, allows easy manipulation of it for a specific socket by using an index into its internal data structures. Note that this is all likely to change. I’m just “sketching in code” here, and the important thing is not that what I end up with is “right”, just that it helps me move towards something that will eventually be “right”.
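
The relationship between the two objects can be sketched roughly as follows. This is a portable model of the idea only, not the code from the repository: every name ending in _sketch, the associate() and complete() methods, and the meaning given to the return value of handle_events() (the events the socket still wants to be polled for) are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed event sink, loosely modelled on the afd_events interface.
class afd_events_sketch
{
   public :

      // Assumption: returns the events that should continue to be polled for.
      virtual std::uint32_t handle_events(std::uint32_t events) = 0;

      virtual ~afd_events_sketch() = default;
};

// Owns one slot of poll state per socket.
class afd_system_sketch
{
   public :

      explicit afd_system_sketch(std::size_t slots) : entries(slots) {}

      void associate(std::size_t index, afd_events_sketch &sink) { entries.at(index).sink = &sink; }

      void poll(std::size_t index, std::uint32_t events) { entries.at(index).requested = events; }

      std::uint32_t requested(std::size_t index) const { return entries.at(index).requested; }

      // Pretends that a poll completed and hands the requested events to the sink.
      void complete(std::size_t index, std::uint32_t signalled)
      {
         entry &e = entries.at(index);

         const std::uint32_t relevant = signalled & e.requested;

         if (e.sink && relevant != 0)
         {
            e.requested = e.sink->handle_events(relevant);
         }
      }

   private :

      struct entry
      {
         afd_events_sketch *sink = nullptr;
         std::uint32_t requested = 0;
      };

      std::vector<entry> entries;
};

// A handle is just "the system plus an index", so a socket never sees the slots.
class afd_handle_sketch
{
   public :

      afd_handle_sketch(afd_system_sketch &system_to_use, std::size_t slot) : owner(system_to_use), index(slot) {}

      void associate(afd_events_sketch &sink) const { owner.associate(index, sink); }

      void poll(std::uint32_t events) const { owner.poll(index, events); }

   private :

      afd_system_sketch &owner;
      const std::size_t index;
};

class counting_sink : public afd_events_sketch
{
   public :

      int calls = 0;

      std::uint32_t handle_events(std::uint32_t) override { ++calls; return 0; }
};

int example_dispatch()
{
   afd_system_sketch afd(1);
   const afd_handle_sketch handle(afd, 0);
   counting_sink sink;

   handle.associate(sink);
   handle.poll(0x0004);

   afd.complete(0, 0x0001);   // not requested, so not dispatched
   afd.complete(0, 0x0004);   // requested, so dispatched to the sink

   assert(afd.requested(0) == 0);

   return sink.calls;
}
```

Because the handle only forwards, the slot layout inside the system object can change later without the socket code noticing.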

I’m using GoogleMock to mock up the socket’s callback interface as this is easy to do and doesn’t require much code.

class mock_tcp_socket_callbacks : public tcp_socket_callbacks
{
   public :

   MOCK_METHOD(void, on_connected, (tcp_socket &), (override));
   MOCK_METHOD(void, on_connection_failed, (tcp_socket &, DWORD), (override));
   MOCK_METHOD(void, on_readable, (tcp_socket &), (override));
   MOCK_METHOD(void, on_readable_oob, (tcp_socket &), (override));
   MOCK_METHOD(void, on_writable, (tcp_socket &), (override));
   MOCK_METHOD(void, on_client_close, (tcp_socket &), (override));
   MOCK_METHOD(void, on_connection_reset, (tcp_socket &), (override));
   MOCK_METHOD(void, on_disconnected, (tcp_socket &), (override));
};

From this point on, the tests that I write are very similar to the tests that I used when I was working on the “understand” project, and they’re written in a similar way.

Here’s a test that shows what happens when we try to connect to an endpoint that doesn’t exist.

TEST(AFDSocket, TestConnectFail)
{
   const auto handles = CreateAfdAndIOCP();

   afd_system afd(handles.afd);

   afd_handle handle(afd, 0);

   mock_tcp_socket_callbacks callbacks;

   tcp_socket socket(handle, callbacks);

   sockaddr_in address {};

   /* Attempt to connect to an address that we won't be able to connect to. */
   address.sin_family = AF_INET;
   address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
   address.sin_port = htons(1);

   socket.connect(reinterpret_cast<const sockaddr &>(address), sizeof(address));

   afd_system *pAfd = GetCompletionAs<afd_system>(handles.iocp, INFINITE);

   EXPECT_CALL(callbacks, on_connection_failed(::testing::_, ::testing::_)).Times(1);

   pAfd->handle_events();
}

This demonstrates the event-driven nature of the code that I’m developing. In the tests we call GetCompletionAs<> to wait for an IOCP completion, and retrieve it if one arrives, when a poll on our \Device\Afd handle completes. In real code, this would likely happen in a thread pool. Once a completion occurs we simply call handle_events() on the afd_system object that it gives us, and that object iterates the returned events and dispatches them to each socket.

There are various design questions around scalability which can be addressed later as, at present, the simple afd_system object only supports a single socket connection. For now, I’m concentrating on the design of the tcp_socket and hoping that, when the time comes to support multiple connections using the afd_system object, we can adjust it without too much change to the existing code.

The socket’s connect() code is pretty simple. The underlying socket has been set up as non-blocking, so we can simply call connect() and then set up a poll that completes when we either establish or fail to establish the connection. This builds on the code that we put together during the “understand” phase.

void tcp_socket::connect(
   const sockaddr &address,
   const int address_length)
{
   if (connection_state != state::created)
   {
      throw std::exception("already connected");
   }

   const int result = ::connect(s, &address, address_length);

   if (result == SOCKET_ERROR)
   {
      const DWORD lastError = WSAGetLastError();

      if (lastError != WSAEWOULDBLOCK)
      {
         throw std::exception("failed to connect");
      }
   }

   connection_state = state::pending_connect;

   events = AFD_POLL_SEND |           // writable which also means "connected"
            AFD_POLL_DISCONNECT |     // client close
            AFD_POLL_ABORT |          // closed
            AFD_POLL_LOCAL_CLOSE |    // we have closed
            AFD_POLL_CONNECT_FAIL;    // outbound connection failed

   afd.poll(events);
}
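
The socket’s handle_events() method isn’t shown in this post, so here is a hedged sketch of how the result of that poll might be turned into a state change while a connect is pending. The flag values are the ones commonly published for \Device\Afd (for example in wepoll) and should be treated as assumptions here; poll_outcome, resolve_connect() and the decision to stop polling for AFD_POLL_SEND once connected are illustrative choices, not the author’s implementation.

```cpp
#include <cassert>
#include <cstdint>

// Assumed flag values, as commonly published for \Device\Afd.
constexpr std::uint32_t AFD_POLL_SEND         = 0x0004;
constexpr std::uint32_t AFD_POLL_DISCONNECT   = 0x0008;
constexpr std::uint32_t AFD_POLL_ABORT        = 0x0010;
constexpr std::uint32_t AFD_POLL_LOCAL_CLOSE  = 0x0020;
constexpr std::uint32_t AFD_POLL_CONNECT_FAIL = 0x0100;

enum class state
{
   created,
   pending_connect,
   connected,
   disconnected
};

// What the caller should do after a poll completes (hypothetical type).
struct poll_outcome
{
   state next;             // the state to move to
   bool connected;         // call on_connected()
   bool failed;            // call on_connection_failed()
   std::uint32_t events;   // events to keep polling for
};

poll_outcome resolve_connect(
   const state current,
   const std::uint32_t requested,
   const std::uint32_t signalled)
{
   poll_outcome outcome{current, false, false, requested};

   if (current != state::pending_connect)
   {
      return outcome;      // nothing connect-related to decide
   }

   if (signalled & AFD_POLL_CONNECT_FAIL)
   {
      outcome.next = state::disconnected;
      outcome.failed = true;
      outcome.events = 0;
   }
   else if (signalled & AFD_POLL_SEND)
   {
      // Writable while a connect is pending means the connect succeeded.
      outcome.next = state::connected;
      outcome.connected = true;
      outcome.events = requested & ~(AFD_POLL_SEND | AFD_POLL_CONNECT_FAIL);
   }

   return outcome;
}

bool example_connect_paths()
{
   const std::uint32_t requested =
      AFD_POLL_SEND | AFD_POLL_DISCONNECT | AFD_POLL_ABORT |
      AFD_POLL_LOCAL_CLOSE | AFD_POLL_CONNECT_FAIL;

   const poll_outcome ok = resolve_connect(state::pending_connect, requested, AFD_POLL_SEND);
   const poll_outcome bad = resolve_connect(state::pending_connect, requested, AFD_POLL_CONNECT_FAIL);

   assert(ok.next == state::connected && ok.connected && !ok.failed);
   assert(bad.next == state::disconnected && bad.failed && !bad.connected);

   return true;
}
```

Keeping the decision in a pure function like this makes it easy to test every combination of state and event without a real socket.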

Reading and writing are also pretty easy. Our read() method returns the number of bytes that have been placed into the supplied buffer. It returns 0 when nothing could be read, and in that case it also sets up a poll for readability and for the various close events. Connection errors and readability notifications come via the callback interface. It may be useful, later, to have this code call the callbacks for client close and connection error events directly, as this would remove the need to poll for events that we already know have happened, but for now we keep things simple.

int tcp_socket::read(
   BYTE *pBuffer,
   int buffer_length)
{
   if (connection_state != state::connected)
   {
      throw std::exception("not connected");
   }

   int bytes = recv(s, reinterpret_cast<char *>(pBuffer), buffer_length, 0);

   if (bytes == 0)
   {
      //handle_events(AFD_POLL_DISCONNECT, 0);
   }

   if (bytes == SOCKET_ERROR)
   {
      const DWORD lastError = WSAGetLastError();

      if (lastError == WSAECONNRESET ||
          lastError == WSAECONNABORTED ||
          lastError == WSAENETRESET)
      {
         //handle_events(AFD_POLL_ABORT, 0);
      }
      else if (lastError != WSAEWOULDBLOCK)
      {
         throw std::exception("failed to read");
      }

      bytes = 0;
   }

   if (bytes == 0)
   {
      events |= (AFD_POLL_RECEIVE |
                 AFD_POLL_DISCONNECT |               // client close
                 AFD_POLL_ABORT |                    // closed
                 AFD_POLL_LOCAL_CLOSE);              // we have closed

      afd.poll(events);
   }

   return bytes;
}
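
Given that contract, code that responds to a readability notification would probably want to keep reading until read() returns 0. Here is a minimal sketch of that loop, assuming only the read() signature shown above; fake_readable_socket, drain() and the byte_t alias (in place of BYTE) are hypothetical helpers so that the example can run without a real socket.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <string>
#include <utility>

using byte_t = unsigned char;   // stands in for BYTE

// Fake socket with the same read() contract: returns the bytes copied,
// or 0 when nothing is available right now.
class fake_readable_socket
{
   public :

      explicit fake_readable_socket(std::string data) : pending(std::move(data)) {}

      int read(byte_t *pBuffer, int buffer_length)
      {
         const std::size_t count = std::min(pending.size(), static_cast<std::size_t>(buffer_length));

         std::memcpy(pBuffer, pending.data(), count);

         pending.erase(0, count);

         return static_cast<int>(count);
      }

   private :

      std::string pending;
};

// Reads until the socket has nothing more for us at the moment.
template <typename Socket>
std::string drain(Socket &socket)
{
   std::string received;

   byte_t buffer[4];   // deliberately tiny so that the loop runs more than once

   for (;;)
   {
      const int bytes = socket.read(buffer, static_cast<int>(sizeof(buffer)));

      if (bytes == 0)
      {
         break;   // with the real socket, read() has now set up the next poll
      }

      received.append(reinterpret_cast<const char *>(buffer), static_cast<std::size_t>(bytes));
   }

   return received;
}

std::string example_drain()
{
   fake_readable_socket socket("hello world");

   const std::string received = drain(socket);

   assert(received.size() == 11);

   return received;
}
```

Stopping at the first 0 matters with the real socket, because that is the call which sets up the poll that produces the next notification.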

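Writing is similar. If fewer bytes are sent than were supplied, write() sets up a poll so that we’re told when the socket becomes writable again.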

int tcp_socket::write(
   const BYTE *pData,
   const int data_length)
{
   if (connection_state != state::connected)
   {
      throw std::exception("not connected");
   }

   int bytes = ::send(s, reinterpret_cast<const char *>(pData), data_length, 0);

   if (bytes == SOCKET_ERROR)
   {
      const DWORD lastError = WSAGetLastError();

      if (lastError == WSAECONNRESET ||
          lastError == WSAECONNABORTED ||
          lastError == WSAENETRESET)
      {
         //handle_events(AFD_POLL_ABORT, 0);
      }
      else if (lastError != WSAEWOULDBLOCK)
      {
         throw std::exception("failed to write");
      }

      bytes = 0;
   }

   if (bytes != data_length)
   {
      events |= (AFD_POLL_SEND |
                 AFD_POLL_DISCONNECT |               // client close
                 AFD_POLL_ABORT |                    // closed
                 AFD_POLL_LOCAL_CLOSE);              // we have closed

      afd.poll(events);
   }

   return bytes;
}
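
Since write() can accept fewer bytes than it was given, the caller has to hold on to the unsent tail and try again once the socket is writable. The sketch below shows one way to do that, assuming only the write() signature shown above; pending_writer, fake_writable_socket and the byte_t alias (in place of BYTE) are hypothetical and exist only to make the example self-contained.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>

using byte_t = unsigned char;   // stands in for BYTE

// Fake socket that accepts at most 'window' bytes until made writable again.
class fake_writable_socket
{
   public :

      explicit fake_writable_socket(std::size_t window_size) : window(window_size), space(window_size) {}

      int write(const byte_t *pData, int data_length)
      {
         const std::size_t count = std::min(space, static_cast<std::size_t>(data_length));

         wire.append(reinterpret_cast<const char *>(pData), count);

         space -= count;

         return static_cast<int>(count);
      }

      void make_writable() { space = window; }

      std::string wire;   // everything that was "sent"

   private :

      std::size_t window;
      std::size_t space;
};

// Queues outbound data and sends as much as the socket will take.
class pending_writer
{
   public :

      template <typename Socket>
      void send(Socket &socket, const std::string &data)
      {
         queued += data;

         flush(socket);
      }

      // Call again from the writable notification while data remains queued.
      template <typename Socket>
      void flush(Socket &socket)
      {
         if (queued.empty())
         {
            return;
         }

         const int sent = socket.write(
            reinterpret_cast<const byte_t *>(queued.data()),
            static_cast<int>(queued.size()));

         queued.erase(0, static_cast<std::size_t>(sent));
      }

      bool idle() const { return queued.empty(); }

   private :

      std::string queued;
};

std::string example_partial_write()
{
   fake_writable_socket socket(5);
   pending_writer writer;

   writer.send(socket, "hello world");   // only the first 5 bytes are accepted

   assert(!writer.idle());

   while (!writer.idle())
   {
      socket.make_writable();            // stands in for a writable notification
      writer.flush(socket);
   }

   return socket.wire;
}
```

Each flush() makes a single write() call, so a short write leads to exactly one poll being set up rather than a second call that would simply report that the socket would block.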

Wrapping up

We now have a simple, event-driven, tcp_socket built on the \Device\Afd interface. Next time we can put together a simple client using this socket.
