struct Connection [src]

Fields

client: *Client
stream_writer: net.Stream.Writer
stream_reader: net.Stream.Reader
pool_node: std.DoublyLinkedList.Node — Entry in ConnectionPool.used or ConnectionPool.free.
port: u16
host_len: u8
proxied: bool
closing: bool
protocol: Protocol

Members

Source

/// One HTTP client connection to a remote host, either plaintext or TLS.
///
/// A `Connection` is always embedded as the `connection` field of a `Plain`
/// or `Tls` wrapper. `host`, `destroy`, `writer`, `reader`, `flush`, `end`
/// and `getReadError` use `protocol` to decide which wrapper it is, then
/// recover that wrapper with `@fieldParentPtr`. Each wrapper lives in a
/// single allocation from `client.allocator`. That allocation also holds a
/// copy of the host name and every I/O buffer the connection uses.
pub const Connection = struct {
    client: *Client,
    /// Buffered writer over the socket. For `.tls` connections this is the
    /// transport the TLS client writes to.
    stream_writer: net.Stream.Writer,
    /// Buffered reader over the socket. For `.tls` connections this is the
    /// transport the TLS client reads from.
    stream_reader: net.Stream.Reader,
    /// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
    pool_node: std.DoublyLinkedList.Node,
    port: u16,
    /// Length of the host name copied into the wrapper's allocation; see `host`.
    host_len: u8,
    proxied: bool,
    closing: bool,
    protocol: Protocol,

    /// Plaintext connection. Allocation layout, in order (see `allocLen`):
    ///
    ///   `Plain` | host bytes | socket read buffer | socket write buffer
    const Plain = struct {
        connection: Connection,

        /// Allocates and initializes a plaintext connection over `stream`.
        /// `remote_host` is copied into the allocation, so the caller keeps
        /// ownership of it. `stream` is not closed on error.
        fn create(
            client: *Client,
            remote_host: []const u8,
            port: u16,
            stream: net.Stream,
        ) error{OutOfMemory}!*Plain {
            const gpa = client.allocator;
            const alloc_len = allocLen(client, remote_host.len);
            const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len);
            errdefer gpa.free(base);
            const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len];
            const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size];
            const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size];
            // The carved-out regions must exactly cover the allocation.
            assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len);
            @memcpy(host_buffer, remote_host);
            const plain: *Plain = @ptrCast(base);
            plain.* = .{
                .connection = .{
                    .client = client,
                    .stream_writer = stream.writer(socket_write_buffer),
                    .stream_reader = stream.reader(socket_read_buffer),
                    .pool_node = .{},
                    .port = port,
                    // NOTE(review): `host_len` is a `u8`, so this requires
                    // `remote_host.len <= 255`; presumably callers enforce
                    // that bound — TODO confirm.
                    .host_len = @intCast(remote_host.len),
                    .proxied = false,
                    .closing = false,
                    .protocol = .plain,
                },
            };
            return plain;
        }

        /// Frees the allocation made by `create`. Does not close the stream.
        fn destroy(plain: *Plain) void {
            const c = &plain.connection;
            const gpa = c.client.allocator;
            const base: [*]align(@alignOf(Plain)) u8 = @ptrCast(plain);
            gpa.free(base[0..allocLen(c.client, c.host_len)]);
        }

        /// Total bytes of one `Plain` allocation for a host of `host_len` bytes.
        fn allocLen(client: *Client, host_len: usize) usize {
            return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size;
        }

        /// The host bytes stored directly after the `Plain` struct.
        fn host(plain: *Plain) []u8 {
            const base: [*]u8 = @ptrCast(plain);
            return base[@sizeOf(Plain)..][0..plain.connection.host_len];
        }
    };

    /// TLS connection. Allocation layout, in order (see `allocLen`):
    ///
    ///   `Tls`
    ///   host bytes                                  (host_len)
    ///   TLS client read buffer   `tls_read_buffer`     (tls_buffer_size + read_buffer_size)
    ///   socket writer buffer     `socket_write_buffer` (tls_buffer_size)
    ///   TLS client write buffer  `tls_write_buffer`    (write_buffer_size)
    ///   socket reader buffer     `socket_read_buffer`  (tls_buffer_size)
    const Tls = struct {
        client: std.crypto.tls.Client,
        connection: Connection,

        /// Allocates a TLS connection over `stream` and initializes the TLS
        /// client on it via `std.crypto.tls.Client.init`.
        ///
        /// `remote_host` is copied into the allocation. It is also passed to
        /// the TLS client as the explicit host. Any `std.crypto.tls.Client.init`
        /// failure is reported as `error.TlsInitializationFailed`. The
        /// allocation is freed on every error path, but `stream` is not closed.
        fn create(
            client: *Client,
            remote_host: []const u8,
            port: u16,
            stream: net.Stream,
        ) error{ OutOfMemory, TlsInitializationFailed }!*Tls {
            const gpa = client.allocator;
            const alloc_len = allocLen(client, remote_host.len);
            const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len);
            errdefer gpa.free(base);
            const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
            // The TLS client wants enough buffer for the max encrypted frame
            // size, and the HTTP body reader wants enough buffer for the
            // entire HTTP header. This means we need a combined upper bound.
            const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
            const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..tls_read_buffer_len];
            const socket_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
            const tls_write_buffer = socket_write_buffer.ptr[socket_write_buffer.len..][0..client.write_buffer_size];
            const socket_read_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.tls_buffer_size];
            // The carved-out regions must exactly cover the allocation.
            assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len);
            @memcpy(host_buffer, remote_host);
            const tls: *Tls = @ptrCast(base);
            // The `.client` initializer below reads `tls.connection` so it can
            // hand the TLS client pointers to the socket reader/writer stored
            // inside `tls`. This relies on `.connection` being written into
            // `tls.*` first, so keep the field order of this literal.
            tls.* = .{
                .connection = .{
                    .client = client,
                    .stream_writer = stream.writer(socket_write_buffer),
                    .stream_reader = stream.reader(socket_read_buffer),
                    .pool_node = .{},
                    .port = port,
                    // NOTE(review): `host_len` is a `u8`, so this requires
                    // `remote_host.len <= 255`; presumably callers enforce
                    // that bound — TODO confirm.
                    .host_len = @intCast(remote_host.len),
                    .proxied = false,
                    .closing = false,
                    .protocol = .tls,
                },
                // TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
                .client = std.crypto.tls.Client.init(
                    tls.connection.stream_reader.interface(),
                    &tls.connection.stream_writer.interface,
                    .{
                        .host = .{ .explicit = remote_host },
                        .ca = .{ .bundle = client.ca_bundle },
                        .ssl_key_log = client.ssl_key_log,
                        .read_buffer = tls_read_buffer,
                        .write_buffer = tls_write_buffer,
                        // This is appropriate for HTTPS because the HTTP headers contain
                        // the content length which is used to detect truncation attacks.
                        .allow_truncation_attacks = true,
                    },
                ) catch return error.TlsInitializationFailed,
            };
            return tls;
        }

        /// Frees the allocation made by `create`. Does not close the stream
        /// and does not send any TLS alert.
        fn destroy(tls: *Tls) void {
            const c = &tls.connection;
            const gpa = c.client.allocator;
            const base: [*]align(@alignOf(Tls)) u8 = @ptrCast(tls);
            gpa.free(base[0..allocLen(c.client, c.host_len)]);
        }

        /// Total bytes of one `Tls` allocation for a host of `host_len` bytes.
        /// The terms follow the layout order documented on `Tls`.
        fn allocLen(client: *Client, host_len: usize) usize {
            const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
            return @sizeOf(Tls) + host_len + tls_read_buffer_len + client.tls_buffer_size +
                client.write_buffer_size + client.tls_buffer_size;
        }

        /// The host bytes stored directly after the `Tls` struct.
        fn host(tls: *Tls) []u8 {
            const base: [*]u8 = @ptrCast(tls);
            return base[@sizeOf(Tls)..][0..tls.connection.host_len];
        }
    };

    pub const ReadError = std.crypto.tls.Client.ReadError || std.net.Stream.ReadError;

    /// Returns the error recorded by this connection's reader(s), or null if
    /// none. For `.tls` connections, an error recorded by the TLS client takes
    /// precedence over one recorded by the socket reader.
    pub fn getReadError(c: *const Connection) ?ReadError {
        return switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *const Tls = @alignCast(@fieldParentPtr("connection", c));
                return tls.client.read_err orelse c.stream_reader.getError();
            },
            .plain => {
                return c.stream_reader.getError();
            },
        };
    }

    /// The underlying socket, as seen by the socket reader.
    fn getStream(c: *Connection) net.Stream {
        return c.stream_reader.getStream();
    }

    /// Returns the remote host name stored in this connection's allocation.
    /// The slice is only valid until `destroy` is called.
    pub fn host(c: *Connection) []u8 {
        return switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
                return tls.host();
            },
            .plain => {
                const plain: *Plain = @alignCast(@fieldParentPtr("connection", c));
                return plain.host();
            },
        };
    }

    /// Closes the underlying stream and frees the connection's allocation.
    /// `c` must not be used afterwards.
    ///
    /// If this is called without calling `flush` or `end`, data will be
    /// dropped unsent.
    pub fn destroy(c: *Connection) void {
        c.getStream().close();
        switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
                tls.destroy();
            },
            .plain => {
                const plain: *Plain = @alignCast(@fieldParentPtr("connection", c));
                plain.destroy();
            },
        }
    }

    /// HTTP protocol from client to server.
    /// This either goes directly to `stream_writer`, or to a TLS client.
    pub fn writer(c: *Connection) *Writer {
        return switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
                return &tls.client.writer;
            },
            .plain => &c.stream_writer.interface,
        };
    }

    /// HTTP protocol from server to client.
    /// This either comes directly from `stream_reader`, or from a TLS client.
    pub fn reader(c: *Connection) *Reader {
        return switch (c.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
                return &tls.client.reader;
            },
            .plain => c.stream_reader.interface(),
        };
    }

    /// Flushes all buffered outgoing data. For `.tls` connections the TLS
    /// client's writer is flushed first, then the socket writer.
    pub fn flush(c: *Connection) Writer.Error!void {
        if (c.protocol == .tls) {
            if (disable_tls) unreachable;
            const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
            try tls.client.writer.flush();
        }
        try c.stream_writer.interface.flush();
    }

    /// If the connection is a TLS connection, sends the close_notify alert.
    ///
    /// Flushes all buffers.
    pub fn end(c: *Connection) Writer.Error!void {
        if (c.protocol == .tls) {
            if (disable_tls) unreachable;
            const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
            try tls.client.end();
        }
        try c.stream_writer.interface.flush();
    }
};