Adventures with \Device\Afd - test-driven design
I’ve been investigating the ‘sparsely documented’ \Device\Afd interface that lies below the Winsock2 layer. Today I take a test-driven approach to building some code that makes this API a little easier to use.
Using the API
In my previous posts on the \Device\Afd interface I focused on exploring and understanding the API itself. Since it’s ‘sparsely documented’, I used unit tests to explore and record my findings. This has left me with code that is easy to come back to and pick up, which is lucky, since it’s been quite a while since I last had time to tinker with it.
The next step is to take what I learnt last time and start to build code that could actually be used for socket communication. To that end, I’ve added a “socket” project to the code on GitHub. See the last article for an explanation of how I’m using GoogleTest and how the code is structured.
Full source can be found here on GitHub.
We’re now working with the socket project.
This isn’t production code; error handling is simply “panic and run away”.
This code is licensed with the MIT license.
A TCP socket class
The first thing I’m going to do is write a simple tcp_socket class that wraps up the socket- and connection-related work. This separates that work from the code that deals with the \Device\Afd API and makes both easier to test and reason about. The end goal is a simple, event-driven socket that will allow me to write a simple TCP client.
class tcp_socket : private afd_events
{
public:

   explicit tcp_socket(
      afd_handle afd,
      tcp_socket_callbacks &callbacks);

   ~tcp_socket() override;

   void connect(
      const sockaddr &address,
      int address_length);

   int write(
      const BYTE *pData,
      int data_length);

   int read(
      BYTE *pBuffer,
      int buffer_length);

   void close();

   enum class shutdown_how       // values match SD_RECEIVE, SD_SEND and SD_BOTH
   {
      receive = 0x00,
      send = 0x01,
      both = 0x02
   };

   void shutdown(
      shutdown_how how);

private:

   ULONG handle_events(
      ULONG eventsToHandle,
      NTSTATUS status) override;

   const afd_handle afd;

   SOCKET s;

   ULONG events;

   tcp_socket_callbacks &callbacks;

   enum class state
   {
      created,
      pending_connect,
      connected,
      disconnected
   };

   state connection_state;
};
The tcp_socket object interacts with the \Device\Afd API in two ways: via the afd_events interface, which allows the AFD code to call into the socket and have it deal with events, and via the afd_handle object, which is our connection to the AFD API. The socket object reports events to the code that uses it via a callback interface, tcp_socket_callbacks. This allows the socket to be event-driven and asynchronous. The connect() method, for example, returns immediately and reports success or failure via a method on the tcp_socket_callbacks interface. The callbacks could come from any thread, so we’ll eventually have to take locking and thread safety into account, but for now we’ll sketch out the design of the code using tests and a single thread.
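The tcp_socket_callbacks interface itself isn’t listed here, but its shape can be inferred from the methods that mock_tcp_socket_callbacks overrides. The sketch below is a hypothetical reconstruction on that basis, plus a hand-rolled recording implementation that can be handy for single-threaded experiments. The DWORD alias and the empty tcp_socket placeholder are assumptions so that it compiles without the Windows headers or the real socket class.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-in for the Windows DWORD type so the sketch builds anywhere.
using DWORD = unsigned long;

// Placeholder only: the real tcp_socket is the class declared above.
class tcp_socket {};

// Hypothetical reconstruction, inferred from mock_tcp_socket_callbacks.
class tcp_socket_callbacks
{
public:
   virtual ~tcp_socket_callbacks() = default;

   virtual void on_connected(tcp_socket &socket) = 0;
   virtual void on_connection_failed(tcp_socket &socket, DWORD error) = 0;
   virtual void on_readable(tcp_socket &socket) = 0;
   virtual void on_readable_oob(tcp_socket &socket) = 0;
   virtual void on_writable(tcp_socket &socket) = 0;
   virtual void on_client_close(tcp_socket &socket) = 0;
   virtual void on_connection_reset(tcp_socket &socket) = 0;
   virtual void on_disconnected(tcp_socket &socket) = 0;
};

// A callback sink that records the name of each notification it receives,
// plus the error code passed to on_connection_failed().
class recording_callbacks : public tcp_socket_callbacks
{
public:
   std::vector<std::string> calls;
   DWORD last_error = 0;

   void on_connected(tcp_socket &) override { calls.push_back("connected"); }

   void on_connection_failed(tcp_socket &, DWORD error) override
   {
      last_error = error;
      calls.push_back("connection_failed");
   }

   void on_readable(tcp_socket &) override { calls.push_back("readable"); }
   void on_readable_oob(tcp_socket &) override { calls.push_back("readable_oob"); }
   void on_writable(tcp_socket &) override { calls.push_back("writable"); }
   void on_client_close(tcp_socket &) override { calls.push_back("client_close"); }
   void on_connection_reset(tcp_socket &) override { calls.push_back("connection_reset"); }
   void on_disconnected(tcp_socket &) override { calls.push_back("disconnected"); }
};
```

A recording sink like this can be simpler than a mock when it’s easier to inspect what was called after the fact than to set up expectations beforehand.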
The first test
As is common with test-driven development, the first test is the most important, and it often exposes most of the dependencies of the code under test. Here’s our first test…
TEST(AFDSocket, TestConstruct)
{
   const auto handles = CreateAfdAndIOCP();

   afd_system afd(handles.afd);

   afd_handle handle(afd, 0);

   mock_tcp_socket_callbacks callbacks;

   tcp_socket socket(handle, callbacks);
}
This shows that I’ve already thought about a quick and dirty way to separate the AFD code from the socket code. I use an afd_system object, which wraps the API. This afd object is then accessed via an afd_handle, an object that wraps a handle to the internals of the AFD object and, in this case, allows easy manipulation of the afd object for a specific socket using an index into its internal data structures. Note that this kind of thing is likely to change. I’m just “sketching in code” here, and the important thing is not that what I end up with is “right”, just that it helps me move towards something that will eventually be “right”.
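To make the “handle plus index” idea concrete, here is a minimal, portable sketch. It is not the code from the repository: this afd_system just records the requested event mask per slot instead of polling \Device\Afd, handle_events() is simplified to return nothing, and associate(), requested() and complete() are hypothetical helpers added so that the sketch can be exercised without Windows.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using ULONG = unsigned long;   // stand-in for the Windows type

// Hypothetical, simplified afd_events: the article's handle_events()
// also returns a ULONG, which is omitted here.
class afd_events
{
public:
   virtual ~afd_events() = default;
   virtual void handle_events(ULONG eventsToHandle, long status) = 0;
};

// Hypothetical afd_system with one slot per socket. A real version would
// issue the poll on the \Device\Afd handle; this one only records the mask.
class afd_system
{
public:
   static constexpr std::size_t max_sockets = 1;   // a single socket, as in the article

   void associate(std::size_t index, afd_events &handler) { slots.at(index).handler = &handler; }
   void poll(std::size_t index, ULONG events) { slots.at(index).requested = events; }
   ULONG requested(std::size_t index) const { return slots.at(index).requested; }

   // Simulate the poll for `index` completing with `ready` events. Polls are
   // one-shot in this sketch, so the socket must poll again to hear more.
   void complete(std::size_t index, ULONG ready, long status = 0)
   {
      slot &s = slots.at(index);
      const ULONG toHandle = ready & s.requested;
      s.requested = 0;
      if (s.handler != nullptr && toHandle != 0)
      {
         s.handler->handle_events(toHandle, status);
      }
   }

private:
   struct slot
   {
      afd_events *handler = nullptr;
      ULONG requested = 0;
   };

   std::array<slot, max_sockets> slots{};
};

// The handle is just "which system, which slot", so a socket can poll
// without knowing how afd_system stores its per-socket state.
class afd_handle
{
public:
   afd_handle(afd_system &owner, std::size_t index) : owner(owner), index(index) {}

   void associate(afd_events &handler) const { owner.associate(index, handler); }
   void poll(ULONG events) const { owner.poll(index, events); }

private:
   afd_system &owner;
   std::size_t index;
};
```

Because poll() is const, a socket can hold its handle as a const member, as tcp_socket does with its afd member.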
I’m using GoogleMock to mock up the socket’s callback interface as this is easy to do and doesn’t require much code.
class mock_tcp_socket_callbacks : public tcp_socket_callbacks
{
public:

   MOCK_METHOD(void, on_connected, (tcp_socket &), (override));
   MOCK_METHOD(void, on_connection_failed, (tcp_socket &, DWORD), (override));
   MOCK_METHOD(void, on_readable, (tcp_socket &), (override));
   MOCK_METHOD(void, on_readable_oob, (tcp_socket &), (override));
   MOCK_METHOD(void, on_writable, (tcp_socket &), (override));
   MOCK_METHOD(void, on_client_close, (tcp_socket &), (override));
   MOCK_METHOD(void, on_connection_reset, (tcp_socket &), (override));
   MOCK_METHOD(void, on_disconnected, (tcp_socket &), (override));
};
From this point on, the tests that I write are very similar to the tests that I used when I was working on the “understand” project, and they’re written in a similar way.
Here’s a test that shows what happens when we try to connect to an endpoint that doesn’t exist.
TEST(AFDSocket, TestConnectFail)
{
   const auto handles = CreateAfdAndIOCP();

   afd_system afd(handles.afd);

   afd_handle handle(afd, 0);

   mock_tcp_socket_callbacks callbacks;

   tcp_socket socket(handle, callbacks);

   sockaddr_in address {};

   /* Attempt to connect to an address that we won't be able to connect to. */

   address.sin_family = AF_INET;
   address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
   address.sin_port = htons(1);

   socket.connect(reinterpret_cast<const sockaddr &>(address), sizeof(address));

   afd_system *pAfd = GetCompletionAs<afd_system>(handles.iocp, INFINITE);

   EXPECT_CALL(callbacks, on_connection_failed(::testing::_, ::testing::_)).Times(1);

   pAfd->handle_events();
}
This demonstrates the event-driven nature of the code that I’m developing. In the tests we can call GetCompletionAs<> to wait for, and possibly retrieve, an IOCP completion when a poll on our \Device\Afd handle completes. In real code, this would most likely happen in a thread pool. Once a completion occurs, we simply call handle_events() on the afd_system object that it returns, and that object iterates the events reported by the poll and dispatches them to each socket.
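An IOCP completion carries a pointer-sized completion key, so one plausible way for a helper like GetCompletionAs<> to work is to hand back the object that was registered with the port by reinterpreting that key. The single-threaded stand-in below illustrates the pattern; completion_queue and get_completion_as are hypothetical names, and unlike a real port (which blocks in GetQueuedCompletionStatus), this queue simply returns nothing when it is empty.

```cpp
#include <cassert>
#include <cstdint>
#include <deque>

// Single-threaded stand-in for an I/O completion port. Each completion
// carries a pointer-sized completion key chosen by whoever posted it.
class completion_queue
{
public:
   void post(std::uintptr_t key) { keys.push_back(key); }

   // Returns false, rather than blocking, when nothing has completed.
   bool try_dequeue(std::uintptr_t &key)
   {
      if (keys.empty())
      {
         return false;
      }
      key = keys.front();
      keys.pop_front();
      return true;
   }

private:
   std::deque<std::uintptr_t> keys;
};

// Hypothetical analogue of GetCompletionAs<>: take the next completion and
// reinterpret its key as a pointer to the object that was registered.
template <typename T>
T *get_completion_as(completion_queue &queue)
{
   std::uintptr_t key = 0;
   return queue.try_dequeue(key) ? reinterpret_cast<T *>(key) : nullptr;
}
```

In a thread pool, each worker would loop, taking completions and calling handle_events() on whatever object the key points to.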
There are various design questions around scalability which can be addressed later since, at present, the simple afd_system object supports only a single socket connection. For now, I’m concentrating on the design of the tcp_socket and hoping that, when the time comes to focus on supporting multiple connections with the afd_system object, we can adjust it without having to change much of the existing code.
The socket’s connect() code is pretty simple. The underlying socket has been set up as non-blocking, so we can simply call connect() and then set up a poll for when we either establish, or fail to establish, the connection. For a non-blocking socket, a connect() that fails with WSAEWOULDBLOCK is expected; it just means that the connection attempt is under way. This builds on the code that we put together during the “understand” phase.
void tcp_socket::connect(
   const sockaddr &address,
   const int address_length)
{
   if (connection_state != state::created)
   {
      throw std::exception("already connected");
   }

   const int result = ::connect(s, &address, address_length);

   if (result == SOCKET_ERROR)
   {
      const DWORD lastError = WSAGetLastError();

      if (lastError != WSAEWOULDBLOCK)
      {
         throw std::exception("failed to connect");
      }
   }

   connection_state = state::pending_connect;

   events = AFD_POLL_SEND |            // writable which also means "connected"
            AFD_POLL_DISCONNECT |      // client close
            AFD_POLL_ABORT |           // closed
            AFD_POLL_LOCAL_CLOSE |     // we have closed
            AFD_POLL_CONNECT_FAIL;     // outbound connection failed

   afd.poll(events);
}
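connect() only arms the poll; what happens when it completes is up to tcp_socket::handle_events(), which isn’t shown here. The following is a hypothetical, portable sketch of how the connection could move between the states declared in the class. It records the name of the callback it would invoke rather than calling tcp_socket_callbacks, and the AFD_POLL_* values are copied from the open-source wepoll project’s afd.h, so treat them as an assumption.

```cpp
#include <cassert>
#include <string>
#include <vector>

using ULONG = unsigned long;   // stand-in for the Windows type

// AFD poll flags; values as in wepoll's afd.h (an assumption, since the
// article uses the constants without showing their definitions).
constexpr ULONG AFD_POLL_RECEIVE      = 0x0001;
constexpr ULONG AFD_POLL_SEND         = 0x0004;
constexpr ULONG AFD_POLL_DISCONNECT   = 0x0008;
constexpr ULONG AFD_POLL_ABORT        = 0x0010;
constexpr ULONG AFD_POLL_LOCAL_CLOSE  = 0x0020;
constexpr ULONG AFD_POLL_CONNECT_FAIL = 0x0100;

enum class state { created, pending_connect, connected, disconnected };

// Hypothetical model of tcp_socket::handle_events(): updates the state and
// records which callback would be called.
struct socket_model
{
   state connection_state = state::pending_connect;   // as left by connect()
   std::vector<std::string> calls;

   void handle_events(ULONG ready)
   {
      if (connection_state == state::pending_connect)
      {
         if (ready & AFD_POLL_CONNECT_FAIL)
         {
            connection_state = state::disconnected;
            calls.push_back("on_connection_failed");
         }
         else if (ready & AFD_POLL_SEND)
         {
            // Becoming writable while connecting means the connect succeeded.
            connection_state = state::connected;
            calls.push_back("on_connected");
         }
         return;
      }

      if (connection_state != state::connected)
      {
         return;
      }

      if (ready & AFD_POLL_ABORT)          // reset or aborted
      {
         connection_state = state::disconnected;
         calls.push_back("on_connection_reset");
         return;
      }
      if (ready & AFD_POLL_LOCAL_CLOSE)    // we have closed
      {
         connection_state = state::disconnected;
         calls.push_back("on_disconnected");
         return;
      }

      if (ready & AFD_POLL_RECEIVE)
         calls.push_back("on_readable");
      if (ready & AFD_POLL_SEND)
         calls.push_back("on_writable");
      if (ready & AFD_POLL_DISCONNECT)     // client close
         calls.push_back("on_client_close");
   }
};
```

Checking AFD_POLL_CONNECT_FAIL before AFD_POLL_SEND keeps a failed connection from being misreported as connected if both bits happen to be set.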
Reading and writing are also pretty easy. Our read() method returns the number of bytes that have been placed into the supplied buffer. If no data is available, or the connection has been closed or reset, it returns 0 and sets up a poll so that we’re told what happens next. Connection errors and readability notifications come via the callback interface. It may be useful, later, to have this code call the callbacks for client close and connection errors directly, as we’d already know that these events had happened and wouldn’t need to poll for them, but for now we keep things simple.
int tcp_socket::read(
   BYTE *pBuffer,
   int buffer_length)
{
   if (connection_state != state::connected)
   {
      throw std::exception("not connected");
   }

   int bytes = recv(s, reinterpret_cast<char *>(pBuffer), buffer_length, 0);

   if (bytes == 0)
   {
      //handle_events(AFD_POLL_DISCONNECT, 0);
   }

   if (bytes == SOCKET_ERROR)
   {
      const DWORD lastError = WSAGetLastError();

      if (lastError == WSAECONNRESET ||
          lastError == WSAECONNABORTED ||
          lastError == WSAENETRESET)
      {
         //handle_events(AFD_POLL_ABORT, 0);
      }
      else if (lastError != WSAEWOULDBLOCK)
      {
         throw std::exception("failed to read");
      }

      bytes = 0;
   }

   if (bytes == 0)
   {
      events |= (AFD_POLL_RECEIVE |
                 AFD_POLL_DISCONNECT |      // client close
                 AFD_POLL_ABORT |           // closed
                 AFD_POLL_LOCAL_CLOSE);     // we have closed

      afd.poll(events);
   }

   return bytes;
}
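The branches in read() boil down to a handful of outcomes. The helper below is a hypothetical refactoring that names them, which would make it straightforward to call on_client_close() or on_connection_reset() directly, as suggested above, instead of leaving the handle_events() calls commented out. The Winsock constants are redefined with their winsock2.h values so that the sketch builds without the Windows headers.

```cpp
#include <cassert>

// Winsock values from winsock2.h, redefined so the sketch is portable.
constexpr int SOCKET_ERROR    = -1;
constexpr int WSAEWOULDBLOCK  = 10035;
constexpr int WSAENETRESET    = 10052;
constexpr int WSAECONNABORTED = 10053;
constexpr int WSAECONNRESET   = 10054;

enum class read_outcome
{
   data,               // bytes were copied into the caller's buffer
   client_closed,      // recv() returned 0: the peer closed its side
   would_block,        // nothing available yet: re-arm the poll
   connection_reset,   // reset, aborted or network reset
   error               // anything else: read() throws
};

// Hypothetical helper mirroring the branches in tcp_socket::read().
// `bytes` is recv()'s return value; `lastError` is only consulted when
// recv() returned SOCKET_ERROR.
read_outcome classify_recv(int bytes, int lastError)
{
   if (bytes == 0)
   {
      return read_outcome::client_closed;
   }
   if (bytes != SOCKET_ERROR)
   {
      return read_outcome::data;
   }

   switch (lastError)
   {
   case WSAECONNRESET:
   case WSAECONNABORTED:
   case WSAENETRESET:
      return read_outcome::connection_reset;
   case WSAEWOULDBLOCK:
      return read_outcome::would_block;
   default:
      return read_outcome::error;
   }
}
```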
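Writing is similar. If not all of the data could be sent, write() polls for AFD_POLL_SEND so that we’re told when the socket is writable again.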
int tcp_socket::write(
   const BYTE *pData,
   const int data_length)
{
   if (connection_state != state::connected)
   {
      throw std::exception("not connected");
   }

   int bytes = ::send(s, reinterpret_cast<const char *>(pData), data_length, 0);

   if (bytes == SOCKET_ERROR)
   {
      const DWORD lastError = WSAGetLastError();

      if (lastError == WSAECONNRESET ||
          lastError == WSAECONNABORTED ||
          lastError == WSAENETRESET)
      {
         //handle_events(AFD_POLL_ABORT, 0);
      }
      else if (lastError != WSAEWOULDBLOCK)
      {
         throw std::exception("failed to write");
      }

      bytes = 0;
   }

   if (bytes != data_length)
   {
      events |= (AFD_POLL_SEND |
                 AFD_POLL_DISCONNECT |      // client close
                 AFD_POLL_ABORT |           // closed
                 AFD_POLL_LOCAL_CLOSE);     // we have closed

      afd.poll(events);
   }

   return bytes;
}
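Because write() may accept only part of the buffer (and returns 0 when it would block), the caller has to hold on to whatever wasn’t sent and try again from on_writable(). The sketch below shows one hypothetical way to do that; pending_writes is not part of the repository, and the writer function stands in for tcp_socket::write().

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

using BYTE = unsigned char;   // stand-in for the Windows type

// Hypothetical caller-side buffer for data that write() couldn't send yet.
class pending_writes
{
public:
   // Stands in for tcp_socket::write(): returns how many bytes were
   // accepted, which may be fewer than requested, or 0.
   using write_fn = std::function<int(const BYTE *, int)>;

   explicit pending_writes(write_fn writer) : writer(std::move(writer)) {}

   // Queue data behind anything already waiting, then try to send.
   void send(const BYTE *pData, int length)
   {
      buffer.insert(buffer.end(), pData, pData + length);
      flush();
   }

   // Call from on_writable() to retry whatever is still queued.
   void flush()
   {
      if (buffer.empty())
      {
         return;
      }
      const int sent = writer(buffer.data(), static_cast<int>(buffer.size()));
      buffer.erase(buffer.begin(), buffer.begin() + sent);
   }

   std::size_t queued() const { return buffer.size(); }

private:
   write_fn writer;
   std::vector<BYTE> buffer;
};
```

Appending new data behind anything already queued keeps bytes in order even if send() is called again before on_writable() arrives.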
Wrapping up
We now have a simple, event-driven tcp_socket built on the \Device\Afd interface.
Next time we can put together a simple client using this socket.