struct Connection [src]
Fields
client: *Client
stream_writer: net.Stream.Writer
stream_reader: net.Stream.Reader
pool_node: std.DoublyLinkedList.Node — Entry in ConnectionPool.used or ConnectionPool.free.
port: u16
host_len: u8
proxied: bool
closing: bool
protocol: Protocol
Members
Source
/// A client-side connection to a remote server, either plain TCP or TLS over TCP.
///
/// A `Connection` is never allocated by itself. It is embedded in either a
/// `Plain` or a `Tls` wrapper, chosen by `protocol`, and the methods below
/// recover the wrapper with `@fieldParentPtr`. Each wrapper lives in a single
/// allocation from `client.allocator`. That allocation also holds the remote
/// host name bytes and every I/O buffer the connection uses. Release it with
/// `destroy`.
pub const Connection = struct {
    /// The client that created this connection. `create` and `destroy` both
    /// read its `allocator` and its buffer-size fields (`read_buffer_size`,
    /// `write_buffer_size`, and for TLS also `tls_buffer_size`).
    client: *Client,
    /// Buffered writer over the TCP socket. For TLS connections it is handed
    /// to the TLS client as its output (see `Tls.create`).
    stream_writer: net.Stream.Writer,
    /// Buffered reader over the TCP socket. For TLS connections it is handed
    /// to the TLS client as its input (see `Tls.create`).
    stream_reader: net.Stream.Reader,
    /// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
    pool_node: std.DoublyLinkedList.Node,
    /// Remote port passed to `create`.
    port: u16,
    /// Number of host-name bytes stored directly after the wrapper struct
    /// (see `host`). `create` narrows the length with `@intCast`, so a host
    /// name longer than 255 bytes is illegal behavior.
    host_len: u8,
    /// Initialized to `false` by `create`.
    /// NOTE(review): presumably set by callers when the connection goes
    /// through a proxy — confirm against the callers.
    proxied: bool,
    /// Initialized to `false` by `create`.
    /// NOTE(review): presumably marks a connection that must not be reused —
    /// confirm against the connection-pool code.
    closing: bool,
    /// Selects the wrapper (`Plain` or `Tls`) this connection is embedded in.
    /// Every method dispatches on it.
    protocol: Protocol,

    /// Heap block for an unencrypted connection, laid out contiguously as:
    /// `[Plain][host bytes][socket read buffer][socket write buffer]`.
    const Plain = struct {
        connection: Connection,

        /// Allocates and initializes a plain connection over `stream`.
        ///
        /// `remote_host` is copied into the allocation, so the caller's slice
        /// need not outlive this call. Its length must fit in a `u8`. The
        /// socket buffers are sized by `client.read_buffer_size` and
        /// `client.write_buffer_size`. On error the allocation is freed and
        /// `stream` is left open; closing it is the caller's job.
        fn create(
            client: *Client,
            remote_host: []const u8,
            port: u16,
            stream: net.Stream,
        ) error{OutOfMemory}!*Plain {
            const gpa = client.allocator;
            const alloc_len = allocLen(client, remote_host.len);
            // Aligned for `Plain` so the start of the block can be cast to `*Plain` below.
            const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len);
            errdefer gpa.free(base);
            // Carve consecutive regions out of the block, right after the struct itself.
            const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len];
            const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size];
            const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size];
            // The carved regions must end exactly at the end of the allocation,
            // i.e. this layout must agree with `allocLen`.
            assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len);
            @memcpy(host_buffer, remote_host);
            const plain: *Plain = @ptrCast(base);
            plain.* = .{
                .connection = .{
                    .client = client,
                    .stream_writer = stream.writer(socket_write_buffer),
                    .stream_reader = stream.reader(socket_read_buffer),
                    .pool_node = .{},
                    .port = port,
                    .host_len = @intCast(remote_host.len),
                    .proxied = false,
                    .closing = false,
                    .protocol = .plain,
                },
            };
            return plain;
        }

        /// Frees the allocation made by `create`. The stream is not closed here;
        /// `Connection.destroy` closes it before calling this.
        fn destroy(plain: *Plain) void {
            const c = &plain.connection;
            const gpa = c.client.allocator;
            const base: [*]align(@alignOf(Plain)) u8 = @ptrCast(plain);
            // The length is recomputed rather than stored (see `allocLen`).
            gpa.free(base[0..allocLen(c.client, c.host_len)]);
        }

        /// Total size in bytes of a `Plain` allocation for a host name of
        /// `host_len` bytes.
        ///
        /// NOTE(review): `destroy` recomputes this from the client's *current*
        /// `read_buffer_size`/`write_buffer_size`. If those fields change after
        /// `create`, the length passed to `free` will not match the allocation.
        fn allocLen(client: *Client, host_len: usize) usize {
            return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size;
        }

        /// Host-name bytes copied in by `create`. They are stored directly
        /// after the struct.
        fn host(plain: *Plain) []u8 {
            const base: [*]u8 = @ptrCast(plain);
            return base[@sizeOf(Plain)..][0..plain.connection.host_len];
        }
    };

    /// Heap block for a TLS connection, laid out contiguously as:
    /// `[Tls][host bytes][tls_read_buffer][tls_write_buffer][socket_write_buffer][socket_read_buffer]`.
    const Tls = struct {
        client: std.crypto.tls.Client,
        connection: Connection,

        /// Allocates a TLS connection over `stream` and initializes the TLS
        /// client in place.
        ///
        /// `remote_host` is copied into the allocation. It is also passed to the
        /// TLS client as the explicit host, and its length must fit in a `u8`.
        /// Any failure from `std.crypto.tls.Client.init` is reported as
        /// `error.TlsInitializationFailed`. On error the allocation is freed
        /// and `stream` is left open.
        fn create(
            client: *Client,
            remote_host: []const u8,
            port: u16,
            stream: net.Stream,
        ) error{ OutOfMemory, TlsInitializationFailed }!*Tls {
            const gpa = client.allocator;
            const alloc_len = allocLen(client, remote_host.len);
            // Aligned for `Tls` so the start of the block can be cast to `*Tls` below.
            const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len);
            errdefer gpa.free(base);
            const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
            // The TLS client wants enough buffer for the max encrypted frame
            // size, and the HTTP body reader wants enough buffer for the
            // entire HTTP header. This means we need a combined upper bound.
            const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
            const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..tls_read_buffer_len];
            // Despite the names, `tls_write_buffer` (tls_buffer_size bytes) backs
            // the socket `stream_writer`. `socket_write_buffer`
            // (write_buffer_size bytes) is handed to the TLS client as its
            // `write_buffer`.
            const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
            const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
            const socket_read_buffer = socket_write_buffer.ptr[socket_write_buffer.len..][0..client.tls_buffer_size];
            // The carved regions must end exactly at the end of the allocation,
            // i.e. this layout must agree with `allocLen`.
            assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len);
            @memcpy(host_buffer, remote_host);
            const tls: *Tls = @ptrCast(base);
            // NOTE(review): `.client` is initialized from pointers into
            // `tls.connection`. This relies on `.connection` being written into
            // `tls.*` before the `.client` initializer runs, so keep `.connection`
            // first in this literal. The pointers remain valid because the
            // block is never moved.
            tls.* = .{
                .connection = .{
                    .client = client,
                    .stream_writer = stream.writer(tls_write_buffer),
                    .stream_reader = stream.reader(socket_read_buffer),
                    .pool_node = .{},
                    .port = port,
                    .host_len = @intCast(remote_host.len),
                    .proxied = false,
                    .closing = false,
                    .protocol = .tls,
                },
                // TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
                .client = std.crypto.tls.Client.init(
                    tls.connection.stream_reader.interface(),
                    &tls.connection.stream_writer.interface,
                    .{
                        .host = .{ .explicit = remote_host },
                        .ca = .{ .bundle = client.ca_bundle },
                        .ssl_key_log = client.ssl_key_log,
                        .read_buffer = tls_read_buffer,
                        .write_buffer = socket_write_buffer,
                        // This is appropriate for HTTPS because the HTTP headers contain
                        // the content length which is used to detect truncation attacks.
                        .allow_truncation_attacks = true,
                    },
                ) catch return error.TlsInitializationFailed,
            };
            return tls;
        }

        /// Frees the allocation made by `create`. The stream is not closed here;
        /// `Connection.destroy` closes it before calling this.
        fn destroy(tls: *Tls) void {
            const c = &tls.connection;
            const gpa = c.client.allocator;
            const base: [*]align(@alignOf(Tls)) u8 = @ptrCast(tls);
            // The length is recomputed rather than stored (see `allocLen`).
            gpa.free(base[0..allocLen(c.client, c.host_len)]);
        }

        /// Total size in bytes of a `Tls` allocation for a host name of
        /// `host_len` bytes. It counts the struct, the host bytes and the four
        /// buffers carved in `create`.
        ///
        /// NOTE(review): `destroy` recomputes this from the client's *current*
        /// buffer-size fields. If any of them change after `create`, the length
        /// passed to `free` will not match the allocation.
        fn allocLen(client: *Client, host_len: usize) usize {
            const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
            return @sizeOf(Tls) + host_len + tls_read_buffer_len + client.tls_buffer_size +
                client.write_buffer_size + client.tls_buffer_size;
        }

        /// Host-name bytes copied in by `create`. They are stored directly
        /// after the struct.
        fn host(tls: *Tls) []u8 {
            const base: [*]u8 = @ptrCast(tls);
            return base[@sizeOf(Tls)..][0..tls.connection.host_len];
        }
    };

    /// Union of the errors a TLS read or a raw stream read can produce.
    pub const ReadError = std.crypto.tls.Client.ReadError || std.net.Stream.ReadError;

    /// Returns the error recorded by the last failed read, or `null`.
    ///
    /// For TLS connections, the TLS client's `read_err` takes precedence over
    /// the underlying stream's error. In every `.tls` branch below,
    /// `if (disable_tls) unreachable;` asserts that TLS connections never
    /// exist when `disable_tls` is set. NOTE(review): `disable_tls` is defined
    /// outside this view and is presumably comptime-known.
    pub fn getReadError(c: *const Connection) ?ReadError {
        return switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *const Tls = @alignCast(@fieldParentPtr("connection", c));
                return tls.client.read_err orelse c.stream_reader.getError();
            },
            .plain => {
                return c.stream_reader.getError();
            },
        };
    }

    /// The underlying TCP stream, obtained through `stream_reader`.
    fn getStream(c: *Connection) net.Stream {
        return c.stream_reader.getStream();
    }

    /// Returns the remote host name copied in at creation. The slice points
    /// into this connection's allocation and is invalidated by `destroy`.
    pub fn host(c: *Connection) []u8 {
        return switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
                return tls.host();
            },
            .plain => {
                const plain: *Plain = @alignCast(@fieldParentPtr("connection", c));
                return plain.host();
            },
        };
    }

    /// Closes the socket, then frees the wrapper allocation (host bytes and
    /// buffers included). `c` must not be used afterwards.
    ///
    /// If this is called without calling `flush` or `end`, data will be
    /// dropped unsent.
    pub fn destroy(c: *Connection) void {
        c.getStream().close();
        switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
                tls.destroy();
            },
            .plain => {
                const plain: *Plain = @alignCast(@fieldParentPtr("connection", c));
                plain.destroy();
            },
        }
    }

    /// HTTP protocol from client to server.
    /// This either goes directly to `stream_writer`, or to a TLS client.
    /// The returned pointer refers to memory owned by this connection.
    pub fn writer(c: *Connection) *Writer {
        return switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
                return &tls.client.writer;
            },
            .plain => &c.stream_writer.interface,
        };
    }

    /// HTTP protocol from server to client.
    /// This either comes directly from `stream_reader`, or from a TLS client.
    /// The returned pointer refers to memory owned by this connection.
    pub fn reader(c: *Connection) *Reader {
        return switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
                return &tls.client.reader;
            },
            .plain => c.stream_reader.interface(),
        };
    }

    /// Flushes all buffered output to the socket.
    ///
    /// For TLS, the TLS client's writer is flushed first, because its output
    /// goes into `stream_writer`. `stream_writer` is then flushed for both
    /// protocols.
    pub fn flush(c: *Connection) Writer.Error!void {
        if (c.protocol == .tls) {
            if (disable_tls) unreachable;
            const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
            try tls.client.writer.flush();
        }
        try c.stream_writer.interface.flush();
    }

    /// If the connection is a TLS connection, sends the close_notify alert.
    ///
    /// Flushes all buffers.
    pub fn end(c: *Connection) Writer.Error!void {
        if (c.protocol == .tls) {
            if (disable_tls) unreachable;
            const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
            try tls.client.end();
        }
        // Push anything the TLS layer (or a plain writer) left buffered out to the socket.
        try c.stream_writer.interface.flush();
    }
}