struct Connection [src]

An interface to either a plain or TLS connection.

Fields

stream: net.Stream
tls_client: if (!disable_tls) *std.crypto.tls.Client else void — Undefined unless protocol is tls.
protocol: Protocol — The protocol that this connection is using.
host: []u8 — The host that this connection is connected to.
port: u16 — The port that this connection is connected to.
proxied: bool = false — Whether this connection is proxied and is not directly connected.
closing: bool = false — Whether this connection is closing when we're done with it.
read_start: BufferSize = 0
read_end: BufferSize = 0
write_end: BufferSize = 0
read_buf: [buffer_size]u8 = undefined
write_buf: [buffer_size]u8 = undefined

Members

Source

/// An interface to either a plain or TLS connection, with fixed-size read
/// and write buffers layered on top of the underlying stream.
pub const Connection = struct {
    stream: net.Stream,
    /// undefined unless protocol is tls.
    tls_client: if (!disable_tls) *std.crypto.tls.Client else void,
    /// The protocol that this connection is using.
    protocol: Protocol,
    /// The host that this connection is connected to.
    host: []u8,
    /// The port that this connection is connected to.
    port: u16,
    /// Whether this connection is proxied and is not directly connected.
    proxied: bool = false,
    /// Whether this connection is closing when we're done with it.
    closing: bool = false,

    /// Index of the first unconsumed byte in `read_buf`.
    read_start: BufferSize = 0,
    /// One past the last valid byte in `read_buf`.
    read_end: BufferSize = 0,
    /// Number of bytes queued in `write_buf` that have not been flushed yet.
    write_end: BufferSize = 0,
    read_buf: [buffer_size]u8 = undefined,
    write_buf: [buffer_size]u8 = undefined,

    /// Capacity of each of the read and write buffers, taken from
    /// `std.crypto.tls.max_ciphertext_record_len`.
    pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
    /// Smallest integer type able to hold any value in `0...buffer_size`.
    const BufferSize = std.math.IntFittingRange(0, buffer_size);

    pub const Protocol = enum { plain, tls };

    /// Reads from the TLS client directly into `buffers`, bypassing `read_buf`.
    /// TLS-specific errors are collapsed into the smaller `ReadError` set.
    pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
        return conn.tls_client.readv(conn.stream, buffers) catch |err| {
            // https://github.com/ziglang/zig/issues/2473
            if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;

            switch (err) {
                error.TlsConnectionTruncated,
                error.TlsRecordOverflow,
                error.TlsDecodeError,
                error.TlsBadRecordMac,
                error.TlsBadLength,
                error.TlsIllegalParameter,
                error.TlsUnexpectedMessage,
                => return error.TlsFailure,
                error.ConnectionTimedOut => return error.ConnectionTimedOut,
                error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
                else => return error.UnexpectedReadFailure,
            }
        };
    }

    /// Reads from the underlying connection (plain or TLS) directly into
    /// `buffers`, bypassing `read_buf`.
    pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
        if (conn.protocol == .tls) {
            // A TLS connection must never exist when TLS support is disabled.
            if (disable_tls) unreachable;

            return conn.readvDirectTls(buffers);
        }

        return conn.stream.readv(buffers) catch |err| switch (err) {
            error.ConnectionTimedOut => return error.ConnectionTimedOut,
            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
            else => return error.UnexpectedReadFailure,
        };
    }

    /// Refills the read buffer with data from the connection.
    /// Does nothing while buffered data remains; returns `error.EndOfStream`
    /// when the connection yields zero bytes.
    pub fn fill(conn: *Connection) ReadError!void {
        if (conn.read_end != conn.read_start) return;

        var iovecs = [1]std.posix.iovec{
            .{ .base = &conn.read_buf, .len = conn.read_buf.len },
        };
        const nread = try conn.readvDirect(&iovecs);
        if (nread == 0) return error.EndOfStream;

        conn.read_start = 0;
        conn.read_end = @intCast(nread);
    }

    /// Returns the current slice of buffered data.
    pub fn peek(conn: *Connection) []const u8 {
        return conn.read_buf[conn.read_start..conn.read_end];
    }

    /// Discards the given number of bytes from the read buffer.
    /// `num` must not exceed the number of buffered bytes (`peek().len`).
    pub fn drop(conn: *Connection, num: BufferSize) void {
        // Catch over-drops here rather than as an out-of-bounds slice in `peek`/`read` later.
        std.debug.assert(num <= conn.read_end - conn.read_start);
        conn.read_start += num;
    }

    /// Reads data from the connection into the given buffer.
    /// Buffered bytes are served first. When nothing is buffered, a single read
    /// fills `buffer` directly and any surplus is kept in `read_buf`.
    /// Returns the number of bytes written to `buffer`.
    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
        // Nothing can be delivered into an empty destination; do not issue a
        // (potentially blocking) read on the connection just to return 0.
        if (buffer.len == 0) return 0;

        const available_read = conn.read_end - conn.read_start;
        const available_buffer = buffer.len;

        if (available_read > available_buffer) {
            // partially read buffered data
            @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
            conn.read_start += @intCast(available_buffer);
            return available_buffer;
        } else if (available_read > 0) {
            // fully read buffered data
            @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
            conn.read_start += available_read;
            return available_read;
        }

        // The caller's buffer is filled first; anything beyond it lands in read_buf.
        var iovecs = [2]std.posix.iovec{
            .{ .base = buffer.ptr, .len = buffer.len },
            .{ .base = &conn.read_buf, .len = conn.read_buf.len },
        };
        const nread = try conn.readvDirect(&iovecs);

        if (nread > buffer.len) {
            conn.read_start = 0;
            conn.read_end = @intCast(nread - buffer.len);
            return buffer.len;
        }

        return nread;
    }

    pub const ReadError = error{
        TlsFailure,
        TlsAlert,
        ConnectionTimedOut,
        ConnectionResetByPeer,
        UnexpectedReadFailure,
        EndOfStream,
    };

    pub const Reader = std.io.Reader(*Connection, ReadError, read);

    pub fn reader(conn: *Connection) Reader {
        return Reader{ .context = conn };
    }

    /// Writes all of `buffer` through the TLS client, bypassing `write_buf`.
    pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void {
        return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) {
            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
            else => return error.UnexpectedWriteFailure,
        };
    }

    /// Writes all of `buffer` to the underlying connection (plain or TLS),
    /// bypassing `write_buf`.
    pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void {
        if (conn.protocol == .tls) {
            // A TLS connection must never exist when TLS support is disabled.
            if (disable_tls) unreachable;

            return conn.writeAllDirectTls(buffer);
        }

        return conn.stream.writeAll(buffer) catch |err| switch (err) {
            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
            else => return error.UnexpectedWriteFailure,
        };
    }

    /// Writes the given buffer to the connection.
    /// Data is queued in `write_buf`. If it does not fit, the queue is flushed
    /// first. Payloads larger than the whole buffer are written directly.
    pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
        if (conn.write_buf.len - conn.write_end < buffer.len) {
            try conn.flush();

            if (buffer.len > conn.write_buf.len) {
                try conn.writeAllDirect(buffer);
                return buffer.len;
            }
        }

        @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer);
        conn.write_end += @intCast(buffer.len);

        return buffer.len;
    }

    /// Returns a buffer to be filled with exactly len bytes to write to the connection.
    /// The bytes are counted as queued immediately; flushes first if there is not enough room.
    pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 {
        if (conn.write_buf.len - conn.write_end < len) try conn.flush();

        defer conn.write_end += len;
        return conn.write_buf[conn.write_end..][0..len];
    }

    /// Flushes the write buffer to the connection.
    /// On failure the queued bytes are left in place.
    pub fn flush(conn: *Connection) WriteError!void {
        if (conn.write_end == 0) return;

        try conn.writeAllDirect(conn.write_buf[0..conn.write_end]);
        conn.write_end = 0;
    }

    pub const WriteError = error{
        ConnectionResetByPeer,
        UnexpectedWriteFailure,
    };

    pub const Writer = std.io.Writer(*Connection, WriteError, write);

    pub fn writer(conn: *Connection) Writer {
        return Writer{ .context = conn };
    }

    /// Closes the connection.
    /// The TLS client (if any) and `host` are released with `allocator`.
    pub fn close(conn: *Connection, allocator: Allocator) void {
        if (conn.protocol == .tls) {
            // A TLS connection must never exist when TLS support is disabled.
            if (disable_tls) unreachable;

            // try to cleanly close the TLS connection, for any server that cares.
            // Deliberately best-effort: close() cannot fail, so errors from writeEnd are ignored.
            _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
            if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close();
            allocator.destroy(conn.tls_client);
        }

        conn.stream.close();
        allocator.free(conn.host);
    }
};