struct Connection [src]
An interface to either a plain or TLS connection.
Fields
stream: net.Stream
tls_client: if (!disable_tls) *std.crypto.tls.Client else void — Undefined unless protocol is tls.
protocol: Protocol — The protocol that this connection is using.
host: []u8 — The host that this connection is connected to.
port: u16 — The port that this connection is connected to.
proxied: bool = false — Whether this connection is proxied and is not directly connected.
closing: bool = false — Whether this connection is closing when we're done with it.
read_start: BufferSize = 0
read_end: BufferSize = 0
write_end: BufferSize = 0
read_buf: [buffer_size]u8 = undefined
write_buf: [buffer_size]u8 = undefined
Members
- allocWriteBuffer (Function)
- buffer_size (Constant)
- close (Function)
- drop (Function)
- fill (Function)
- flush (Function)
- peek (Function)
- Protocol (enum)
- read (Function)
- reader (Function)
- Reader (Type)
- ReadError (Error Set)
- readvDirect (Function)
- readvDirectTls (Function)
- write (Function)
- writeAllDirect (Function)
- writeAllDirectTls (Function)
- WriteError (Error Set)
- writer (Function)
- Writer (Type)
Source
pub const Connection = struct {
    stream: net.Stream,
    /// undefined unless protocol is tls.
    tls_client: if (!disable_tls) *std.crypto.tls.Client else void,

    /// The protocol that this connection is using.
    protocol: Protocol,

    /// The host that this connection is connected to.
    host: []u8,

    /// The port that this connection is connected to.
    port: u16,

    /// Whether this connection is proxied and is not directly connected.
    proxied: bool = false,

    /// Whether this connection is closing when we're done with it.
    closing: bool = false,

    /// Index of the first unconsumed byte in `read_buf`.
    read_start: BufferSize = 0,
    /// One past the last valid byte in `read_buf`;
    /// `read_buf[read_start..read_end]` is the buffered, unconsumed input.
    read_end: BufferSize = 0,
    /// Number of bytes queued in `write_buf` that have not been flushed yet.
    write_end: BufferSize = 0,
    read_buf: [buffer_size]u8 = undefined,
    write_buf: [buffer_size]u8 = undefined,

    /// Capacity of each of `read_buf` and `write_buf`
    /// (`std.crypto.tls.max_ciphertext_record_len`).
    pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
    const BufferSize = std.math.IntFittingRange(0, buffer_size);

    pub const Protocol = enum { plain, tls };

    /// Reads directly from the TLS client into `buffers`, bypassing `read_buf`.
    /// Only meaningful when `protocol == .tls` (otherwise `tls_client` is undefined).
    /// Errors are collapsed into `ReadError`:
    ///   - any `TlsAlert*` error becomes `error.TlsAlert`;
    ///   - record/decoding failures become `error.TlsFailure`;
    ///   - `BrokenPipe` is reported as `error.ConnectionResetByPeer`;
    ///   - anything else becomes `error.UnexpectedReadFailure`.
    pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
        return conn.tls_client.readv(conn.stream, buffers) catch |err| {
            // The TLS client has a whole family of `TlsAlert*` errors; match them
            // by name prefix instead of listing each one.
            // https://github.com/ziglang/zig/issues/2473
            if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;

            switch (err) {
                error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
                error.ConnectionTimedOut => return error.ConnectionTimedOut,
                error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
                else => return error.UnexpectedReadFailure,
            }
        };
    }

    /// Reads directly from the underlying transport (TLS or plain stream) into
    /// `buffers`, bypassing `read_buf`, and maps transport errors into `ReadError`.
    pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
        if (conn.protocol == .tls) {
            if (disable_tls) unreachable;

            return conn.readvDirectTls(buffers);
        }

        return conn.stream.readv(buffers) catch |err| switch (err) {
            error.ConnectionTimedOut => return error.ConnectionTimedOut,
            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
            else => return error.UnexpectedReadFailure,
        };
    }

    /// Refills the read buffer with data from the connection.
    /// Does nothing while buffered data is still available. Otherwise performs a
    /// single direct read into `read_buf`; returns `error.EndOfStream` if that
    /// read yields 0 bytes.
    pub fn fill(conn: *Connection) ReadError!void {
        if (conn.read_end != conn.read_start) return;

        var iovecs = [1]std.posix.iovec{
            .{ .base = &conn.read_buf, .len = conn.read_buf.len },
        };
        const nread = try conn.readvDirect(&iovecs);
        if (nread == 0) return error.EndOfStream;

        conn.read_start = 0;
        conn.read_end = @intCast(nread);
    }

    /// Returns the current slice of buffered data.
    /// The slice aliases `read_buf`: it stays valid only until the data is
    /// consumed and the buffer is refilled by `fill` or `read`.
    pub fn peek(conn: *Connection) []const u8 {
        return conn.read_buf[conn.read_start..conn.read_end];
    }

    /// Discards the given number of bytes from the read buffer.
    /// `num` must not exceed the number of buffered bytes (`peek().len`).
    pub fn drop(conn: *Connection, num: BufferSize) void {
        // Catch over-consumption here instead of letting `read_start` pass
        // `read_end`, which would only fail later (bad slice in `peek`,
        // underflow in `read`).
        std.debug.assert(num <= conn.read_end - conn.read_start);
        conn.read_start += num;
    }

    /// Reads data from the connection into the given buffer.
    /// Buffered data is returned first, without touching the transport. When
    /// nothing is buffered, one vectored read fills `buffer` and spills any
    /// excess into `read_buf` for subsequent calls. Returns 0 at end of stream,
    /// and returns 0 immediately (without blocking) when `buffer` is empty.
    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
        // Nothing can be copied out, so don't block on the transport.
        if (buffer.len == 0) return 0;

        const available_read = conn.read_end - conn.read_start;
        const available_buffer = buffer.len;

        if (available_read > available_buffer) { // partially read buffered data
            @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
            conn.read_start += @intCast(available_buffer);

            return available_buffer;
        } else if (available_read > 0) { // fully read buffered data
            @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
            conn.read_start += available_read;

            return available_read;
        }

        var iovecs = [2]std.posix.iovec{
            .{ .base = buffer.ptr, .len = buffer.len },
            .{ .base = &conn.read_buf, .len = conn.read_buf.len },
        };
        const nread = try conn.readvDirect(&iovecs);

        if (nread > buffer.len) {
            // The caller's buffer is full; keep the overflow that landed in read_buf.
            conn.read_start = 0;
            conn.read_end = @intCast(nread - buffer.len);

            return buffer.len;
        }

        return nread;
    }

    pub const ReadError = error{
        TlsFailure,
        TlsAlert,
        ConnectionTimedOut,
        ConnectionResetByPeer,
        UnexpectedReadFailure,
        EndOfStream,
    };

    pub const Reader = std.io.Reader(*Connection, ReadError, read);

    /// Returns a buffered reader over this connection (see `read`).
    pub fn reader(conn: *Connection) Reader {
        return Reader{ .context = conn };
    }

    /// Writes all of `buffer` through the TLS client, bypassing `write_buf`.
    /// Only meaningful when `protocol == .tls`.
    pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void {
        return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) {
            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
            else => return error.UnexpectedWriteFailure,
        };
    }

    /// Writes all of `buffer` to the underlying transport (TLS or plain stream),
    /// bypassing `write_buf`, and maps transport errors into `WriteError`.
    pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void {
        if (conn.protocol == .tls) {
            if (disable_tls) unreachable;

            return conn.writeAllDirectTls(buffer);
        }

        return conn.stream.writeAll(buffer) catch |err| switch (err) {
            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
            else => return error.UnexpectedWriteFailure,
        };
    }

    /// Writes the given buffer to the connection.
    /// Data is queued in `write_buf`; if it doesn't fit, pending data is flushed
    /// first, and inputs larger than the whole buffer are written directly.
    /// Always returns `buffer.len` on success.
    pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
        if (conn.write_buf.len - conn.write_end < buffer.len) {
            try conn.flush();

            if (buffer.len > conn.write_buf.len) {
                try conn.writeAllDirect(buffer);
                return buffer.len;
            }
        }

        @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer);
        conn.write_end += @intCast(buffer.len);

        return buffer.len;
    }

    /// Returns a buffer to be filled with exactly len bytes to write to the connection.
    /// Flushes pending data first if there isn't room. The `len` bytes are counted
    /// as queued as soon as this returns, so the caller must fill the whole slice
    /// before the next `write`/`flush`.
    pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 {
        if (conn.write_buf.len - conn.write_end < len) try conn.flush();

        defer conn.write_end += len;
        return conn.write_buf[conn.write_end..][0..len];
    }

    /// Flushes the write buffer to the connection.
    /// The buffer is only marked empty after the write succeeds.
    pub fn flush(conn: *Connection) WriteError!void {
        if (conn.write_end == 0) return;

        try conn.writeAllDirect(conn.write_buf[0..conn.write_end]);
        conn.write_end = 0;
    }

    pub const WriteError = error{
        ConnectionResetByPeer,
        UnexpectedWriteFailure,
    };

    pub const Writer = std.io.Writer(*Connection, WriteError, write);

    /// Returns a buffered writer over this connection (see `write`).
    pub fn writer(conn: *Connection) Writer {
        return Writer{ .context = conn };
    }

    /// Closes the connection.
    /// For TLS, attempts a clean TLS shutdown (errors ignored), closes the key log
    /// file if present, and destroys `tls_client`. Then closes the stream and frees
    /// `host`. `allocator` must be the one that allocated `tls_client` and `host`.
    /// Pending data in `write_buf` is NOT flushed.
    pub fn close(conn: *Connection, allocator: Allocator) void {
        if (conn.protocol == .tls) {
            if (disable_tls) unreachable;

            // try to cleanly close the TLS connection, for any server that cares.
            // Best effort: the connection is being torn down regardless.
            _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
            if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close();
            allocator.destroy(conn.tls_client);
        }

        conn.stream.close();
        allocator.free(conn.host);
    }
};