UDP-WG Implementation
Loading...
Searching...
No Matches
udp.h
1#pragma once
2
3#include <format> // For formatting strings.
4#include "shared.h" // For the network code.
5
6
7using namespace shared;
8
// Forward declarations, so the packet can befriend the network thread.
// Forward declarations like this are identical to how you might recognize them
// in C. If you have function A calling function B, but function B is
// actually defined after function A's definition, the compiler isn't smart
// enough to read the entire file and notice the subsequent definition. To solve it,
// we typically just have a Forward Declaration, such as:
//
// void B();
//
// void A() {B();}
// void B() {}
//
// Here, because UDP is relied upon by WireGuard, which itself is relied
// upon by Network, we cannot just include those headers, so we instead
// forward declare the namespaces and the needed objects, so that
// the packet knows that the network thread exists.
//
26namespace wireguard {struct config;}
27namespace network {void thread(port_t, wireguard::config wg);}
28
35namespace udp {
36
59 class packet {
60 private:
61
62 // The thread gets privileged access to the packet: IE they can change the source.
63 friend void network::thread(port_t, wireguard::config wg);
64
// The UDP pseudo-header, as per the Reference.
//
// It is copied to and from the wire byte-for-byte (see buffer() and the parsing
// constructors), so its layout must stay exactly 12 octets.
struct pseudo_header {

    // The address that this packet is coming from. We use this for replies.
    uint32_t src_addr = 0;

    // The destination of the packet. The Network Thread uses this to resolve
    // and connect to the destination.
    uint32_t dst_addr = 0;

    // The Reference stipulates a "zero" octet followed by a protocol octet.
    // We store both in a single uint16_t, as 17 = 0000 0000 0001 0001b, so the
    // numerically high octet is the zero and the low octet is the protocol.
    // The zero octet can be checked for corruption (it should always be zero),
    // and it pads the pseudo-header to a neat 96 bits (12 octets), which keeps
    // every field word-aligned.
    // (See: https://en.wikipedia.org/wiki/Data_structure_alignment)
    uint16_t protocol = 17;

    // The length of the entire packet.
    uint16_t length = 0;
};

// The struct is memcpy'd directly into network buffers; catch accidental padding.
static_assert(sizeof(pseudo_header) == 12);
90
91
// The UDP header, as per the Reference.
//
// Like the pseudo-header, it is copied to and from the wire byte-for-byte, so
// its layout must stay exactly 8 octets.
struct header {

    // The port the sender used to transmit the packet. The Reference tells us
    // that this value isn't strictly necessary, and can be zeroed, but we use it
    // since it allows us to easily look up the FD.
    uint16_t src_port = 0;

    // The port of the destination to which we send the packet.
    uint16_t dst_port = 0;

    // The length of the entire packet, in octets. An octet is exactly eight
    // bits; networking documents use the term because a "byte" used to be
    // platform dependent (https://en.wikipedia.org/wiki/Octet_(computing)).
    // The same value is duplicated in the pseudo-header.
    uint16_t length = 0;

    // The checksum is the 16-bit one's complement of the one's complement sum
    // of the pseudo-header (source and destination address, zero, protocol and
    // UDP length) and the data.
    uint16_t check = 0;
};

// The struct is memcpy'd directly into network buffers; catch accidental padding.
static_assert(sizeof(header) == 8);
116
117 // Hold each part.
118 pseudo_header p = {};
119 header h = {};
120 std::string content = {};
121
122 // These are compile-time constants; the length of the entire header,
123 // and how much space is available for data. Since the length value in the
124 // header is 16 bits per the Reference, we can only have data so large
125 static constexpr uint16_t h_length = sizeof(pseudo_header) + sizeof(header);
126 static constexpr uint16_t available = UINT16_MAX - h_length;
127
128
// Compute the 16-bit checksum of an arbitrary region of memory.
//
// The region is read as big-endian 16-bit words (a trailing odd octet is padded
// with a zero octet, as the Reference stipulates). Each word is complemented and
// the results are added with end-around carry, which in one's complement
// arithmetic equals the complement of the one's complement sum of the words.
//
// The previous implementation lost data in three ways: integer promotion of
// `~byte` overwrote the high octet of every word, the final word was never
// added, and carries out of bit 15 were discarded.
//
// data: start of the region to checksum.
// size: number of octets to read from `data`.
// Returns the checksum; 0 for an empty region.
template<typename T> static uint16_t checksum(const T* data, const size_t& size) {
    const auto* octets = reinterpret_cast<const uint8_t*>(data);

    // 32 bits of headroom so the carry out of bit 15 can be folded back in.
    uint32_t sum = 0;

    for (size_t i = 0; i < size; i += 2) {
        const uint32_t high = octets[i];
        const uint32_t low = (i + 1 < size) ? octets[i + 1] : 0u;
        const uint16_t word = static_cast<uint16_t>((high << 8) | low);

        // Add the complemented word, then fold the carry (end-around carry).
        sum += static_cast<uint16_t>(~word);
        sum = (sum & 0xFFFFu) + (sum >> 16);
    }

    return static_cast<uint16_t>(sum);
}
168
169
175 packet(const fd_t& fd) {
176 // Without receiving at least the header, we have no
177 // idea how large the packet is. So we just consume the
178 // packet section-by-section, constructing the final packet to return.
179 if (recv(fd, reinterpret_cast<void*>(&p), sizeof(p), 0) < 1)
180 throw std::runtime_error("Failed to receive packet");
181 if (recv(fd, reinterpret_cast<void*>(&h), sizeof(h), 0) < 1)
182 throw std::runtime_error("Failed to receive packet");
183
184 // Figure out how big the data is by removing the headers.
185 size_t length = h.length - sizeof(p) - sizeof(h);
186 if (length > available) throw std::length_error("Invalid packet!");
187
188 char buffer[length] = {};
189 if (recv(fd, reinterpret_cast<void*>(&buffer[0]), length, 0) < 1)
190 throw std::runtime_error("Failed to receive packet");
191 content = std::string(buffer, length);
192
193 // Validate the checksum.
194 auto check = checksum(&p, sizeof(p));
195 check += checksum(content.c_str(), content.length());
196 if (check != h.check) throw std::runtime_error("Checksum error!");
197 }
198
205 packet(const connection& src, const connection& dst, const std::string& data) {construct(src, dst, data);}
206
207
214 void construct(const connection& src, const connection& dst, const std::string& data) {
215 // Figure out how many bytes we can take from the data, and the packet size.
216 uint16_t used = data.length() > available ? available : data.length();
217 uint16_t length = used + h_length;
218
219 // Create our headers and content
220 p = {.src_addr = src.pair.a, .dst_addr = dst.pair.a, .length = length};
221 h = {.src_port = src.pair.p, .dst_port = dst.pair.p, .length = length};
222
223 content = data.substr(0, used);
224
225 // Compute the checksum of the pseudo-header, and the data string.
226 h.check = checksum(&p, sizeof(p));
227 h.check += checksum(content.c_str(), content.length());
228
229 // The Reference dictates that if the checksum is 0, it should be set to all 1.
230 // All 0 indicates that checksumming wasn't used.
231 if (h.check == 0) h.check = UINT16_MAX;
232 }
233
234 void set_source(const connection& src) {p.src_addr = src.pair.a; h.src_port = src.pair.p;}
235 void set_dest(const connection& dst) {p.dst_addr = dst.pair.a; h.dst_port = dst.pair.p;}
236
237 public:
238
239 packet() = default;
240
248 packet(const connection& dst, const std::string& data) {construct(self, dst, data);}
249
254 packet(const std::string& in) {
255 size_t index = 0;
256
257 // Get the header
258 memcpy(&p, in.c_str(), sizeof(p));
259 index += sizeof(p);
260 memcpy(&h, in.c_str() + index, sizeof(h));
261 index += sizeof(h);
262
263 // Figure out how big the data is by removing the headers.
264 size_t length = h.length - sizeof(p) - sizeof(h);
265 char buffer[length] = {};
266
267 memcpy(&buffer[0], in.c_str() + index, length);
268 content = std::string(buffer, length);
269 }
270
271
278 template <typename T> packet(const connection& dst, const T& data, const size_t& size) {
279 *this = packet(dst, {reinterpret_cast<const char*>(&data), size});
280 }
281
282
// Render the packet as a text box, 29 characters wide, for printing.
//
// The payload is written as-is, 25 octets per row. If it isn't ASCII the box
// may render badly, but the only packets carrying such data are WireGuard's,
// and those get decrypted before the user sees them.
//
// Returns the rendered box, ending in a newline.
std::string str() const {
    // Payload octets per row; the row itself is "| " + 25 columns + " |".
    constexpr size_t columns = 25;
    const std::string rule = "=============================\n";

    // Section titles are centred across the 27 characters between the borders
    // so that they line up with every other row of the box.
    const auto title = [](const char* text) {return std::format("|{:^27}|\n", text);};

    // Print the pseudo-header.
    std::string out = "0======1======2======3======4\n";
    out += title("PSEUDO-HEADER");
    out += rule;
    out += std::format("| {:^25} |\n", p.src_addr);
    out += std::format("| {:^25} |\n", p.dst_addr);
    out += std::format("| {:^4} | {:^4} | {:^11} |\n", p.protocol >> 8, p.protocol & 0xFF, p.length);

    // Print the header.
    out += rule;
    out += title("HEADER");
    out += rule;
    out += std::format("| {:^11} | {:^11} |\n", h.src_port, h.dst_port);
    out += std::format("| {:^11} | {:^11} |\n", h.length, h.check);

    // Print the content. There is always at least one row, so an empty payload
    // shows a blank row rather than indexing into an empty string.
    out += rule;
    out += title("DATA");
    out += rule;

    size_t offset = 0;
    do {
        const std::string row = content.substr(offset, columns);
        out += "| ";
        out += row;
        // Pad by octet count so the right border stays in the same column.
        out.append(columns - row.size(), ' ');
        out += " |\n";
        offset += columns;
    } while (offset < content.length());

    out += rule;
    return out;
}
337
338
344 template <typename T> const T cast() const {return *reinterpret_cast<const T*>(content.c_str());}
345
346
351 std::string data() const {return content;}
352
353
358 std::string buffer() const {
359 std::string buffer = {};
360 buffer.append(reinterpret_cast<const char*>(&p), sizeof(pseudo_header));
361 buffer.append(reinterpret_cast<const char*>(&h), sizeof(header));
362 buffer.append(content);
363 return buffer;
364 }
365
366
371 connection destination() const {return {.pair = {.a = p.dst_addr, .p = h.dst_port}};}
372
373
378 connection source() const {return {.pair = {.a = p.src_addr, .p = h.src_port}};}
379
380
391 static packet empty(const connection& dst) {
392 packet ret;
393 // Create our headers and content
394 ret.p = {.src_addr = 0, .dst_addr = dst.pair.a};
395 ret.h = {.src_port = 0, .dst_port = dst.pair.p};
396 return ret;
397 }
398 };
399}
A UDP Packet.
Definition udp.h:59
std::string str() const
Print the packet.
Definition udp.h:288
static packet empty(const connection &dst)
An empty packet.
Definition udp.h:391
std::string data() const
Return the data.
Definition udp.h:351
connection source() const
Get the address and port of the source.
Definition udp.h:378
std::string buffer() const
Create a buffer of the packet that can be sent across the network.
Definition udp.h:358
const T cast() const
Cast the content of the packet as a type.
Definition udp.h:344
connection destination() const
Get the address and port of the destination.
Definition udp.h:371
packet(const connection &dst, const T &data, const size_t &size)
Construct a packet from any variable data, and its size.
Definition udp.h:278
packet(const connection &dst, const std::string &data)
Create a packet from a string.
Definition udp.h:248
packet(const std::string &in)
Construct a packet from a string buffer.
Definition udp.h:254
The core networking namespace.
Definition network.h:23
void thread(port_t port, wireguard::config wg={})
The Network Thread.
Definition network.h:169
The shared namespace.
Definition shared.h:19
This namespace includes the UDP implementation.
Definition udp.h:35
This namespace includes the WireGuard implementation.
Definition udp.h:26
A WireGuard Configuration.
Definition wireguard.h:113
Definition shared.h:45