struct rsa [src]

Members

Source

/// RSA signature verification (RFC 3447 / RFC 8017) for moduli of up to
/// `max_modulus_bits` bits. Only public-key operations are implemented here.
pub const rsa = struct {
    const max_modulus_bits = 4096;
    const Uint = std.crypto.ff.Uint(max_modulus_bits);
    const Modulus = std.crypto.ff.Modulus(max_modulus_bits);
    const Fe = Modulus.Fe;

    /// RFC 3447 8.1 RSASSA-PSS
    pub const PSSSignature = struct {
        /// Copies `msg` to the front of a `modulus_len`-byte array and zero-fills
        /// the remainder (the padding is appended on the right, not the left).
        /// `msg.len` must not exceed `modulus_len`.
        pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 {
            var result: [modulus_len]u8 = undefined;
            @memcpy(result[0..msg.len], msg);
            @memset(result[msg.len..], 0);
            return result;
        }

        pub const VerifyError = EncryptError || error{InvalidSignature};

        /// Verifies an RSASSA-PSS signature `sig` over `msg`. `Hash` is used both
        /// for the message digest and for MGF1, and the salt length is expected
        /// to equal `Hash.digest_length`.
        pub fn verify(
            comptime modulus_len: usize,
            sig: [modulus_len]u8,
            msg: []const u8,
            public_key: PublicKey,
            comptime Hash: type,
        ) VerifyError!void {
            try concatVerify(modulus_len, sig, &.{msg}, public_key, Hash);
        }

        /// Same as `verify`, but the signed message is the concatenation of all
        /// slices in `msg`.
        pub fn concatVerify(
            comptime modulus_len: usize,
            sig: [modulus_len]u8,
            msg: []const []const u8,
            public_key: PublicKey,
            comptime Hash: type,
        ) VerifyError!void {
            const mod_bits = public_key.n.bits();
            const em_dec = try encrypt(modulus_len, sig, public_key);
            try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash);
        }

        /// EMSA-PSS-VERIFY (RFC 3447 9.1.2). `em` is the output of the RSA public
        /// operation (modulus-sized), `emBit` is the modulus bit length minus one
        /// and `sLen` is the expected salt length. Returns `InvalidSignature` when
        /// the encoding is inconsistent.
        fn EMSA_PSS_VERIFY(msg: []const []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) VerifyError!void {
            const hLen = Hash.digest_length;

            // 1. If the length of M is greater than the input limitation for
            //    the hash function (2^61 - 1 octets for SHA-1), output
            //    "inconsistent" and stop.
            // All the cryptographic hash functions in the standard library have a limit of >= 2^61 - 1.
            // Even then, this check is only there for paranoia. In the context of TLS certificates, emBit cannot exceed 4096.
            if (emBit >= 1 << 61) return error.InvalidSignature;

            // emLen = \ceil(emBits/8)
            const emLen = (emBit + 7) / 8;

            // RFC 8017 8.1.2 step 2c: EM = I2OSP(m, emLen). The RSA output is
            // modulus-sized, which is one octet longer than emLen whenever
            // (modBits - 1) is a multiple of 8. Any such extra leading octets
            // must be zero, otherwise m does not fit in emLen octets.
            if (em.len < emLen) return error.InvalidSignature;
            const extra = em.len - emLen;
            if (!std.mem.allEqual(u8, em[0..extra], 0)) return error.InvalidSignature;
            const encoded = em[extra..];

            // 2. Let mHash = Hash(M), an octet string of length hLen.
            var mHash: [hLen]u8 = undefined;
            {
                var hasher: Hash = .init(.{});
                for (msg) |part| hasher.update(part);
                hasher.final(&mHash);
            }

            // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
            if (emLen < hLen + sLen + 2) return error.InvalidSignature;

            // 4. If the rightmost octet of EM does not have hexadecimal value
            //    0xbc, output "inconsistent" and stop.
            if (encoded[emLen - 1] != 0xbc) return error.InvalidSignature;

            // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
            //    and let H be the next hLen octets.
            const db_len = emLen - hLen - 1;
            const maskedDB = encoded[0..db_len];
            const h = encoded[db_len..][0..hLen];

            // 6. If the leftmost 8emLen - emBits bits of the leftmost octet in
            //    maskedDB are not all equal to zero, output "inconsistent" and
            //    stop. 8emLen - emBits is always in 0..7.
            const zero_bits: u3 = @intCast(emLen * 8 - emBit);
            // Keeps the rightmost 8 - zero_bits bits of an octet.
            const low_bits_mask: u8 = @as(u8, 0xff) >> zero_bits;
            if ((maskedDB[0] & ~low_bits_mask) != 0) return error.InvalidSignature;

            // 7. Let dbMask = MGF(H, emLen - hLen - 1).
            // MGF1 produces whole digest blocks, so the scratch buffer must hold
            // db_len rounded up to a multiple of hLen.
            var mgf_out_buf: [max_modulus_bits / 8]u8 = undefined;
            const mgf_out_len = (db_len + hLen - 1) / hLen * hLen;
            if (mgf_out_len > mgf_out_buf.len) {
                // Modulus > 4096 bits
                return error.InvalidSignature;
            }
            const db = MGF1(Hash, mgf_out_buf[0..mgf_out_len], h, db_len);

            // 8. Let DB = maskedDB \xor dbMask (computed in place over dbMask).
            for (db, maskedDB) |*d, m| d.* ^= m;

            // 9. Set the leftmost 8emLen - emBits bits of the leftmost octet
            //    in DB to zero.
            db[0] &= low_bits_mask;

            // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not
            //     zero or if the octet at position emLen - hLen - sLen - 1 (the
            //     leftmost position is "position 1") does not have hexadecimal
            //     value 0x01, output "inconsistent" and stop.
            // ps_len = emLen - hLen - sLen - 2, which is non-negative by step 3.
            // Every octet of PS must be checked, not just the last one.
            const ps_len = db_len - sLen - 1;
            if (!std.mem.allEqual(u8, db[0..ps_len], 0x00)) return error.InvalidSignature;
            if (db[ps_len] != 0x01) return error.InvalidSignature;

            // 11. Let salt be the last sLen octets of DB.
            const salt = db[ps_len + 1 ..];

            // 12. Let
            //     M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
            //     M' is an octet string of length 8 + hLen + sLen with eight
            //     initial zero octets.
            if (sLen > hLen) {
                // A seed larger than the hash length would be useless
                return error.InvalidSignature;
            }
            var m_p_buf: [8 + hLen + hLen]u8 = undefined;
            const m_p = m_p_buf[0 .. 8 + hLen + sLen];
            @memset(m_p[0..8], 0);
            @memcpy(m_p[8..][0..hLen], &mHash);
            @memcpy(m_p[8 + hLen ..], salt);

            // 13. Let H' = Hash(M'), an octet string of length hLen.
            var h_p: [hLen]u8 = undefined;
            Hash.hash(m_p, &h_p, .{});

            // 14. If H = H', output "consistent". Otherwise, output
            //     "inconsistent".
            if (!std.mem.eql(u8, h, &h_p)) return error.InvalidSignature;
        }

        /// MGF1 (RFC 3447 B.2.1): fills `out` with Hash(seed || C) for the 32-bit
        /// big-endian counters C = 0, 1, ... and returns the first `len` bytes.
        /// `out.len` must be at least `len` rounded up to a multiple of
        /// `Hash.digest_length`.
        fn MGF1(comptime Hash: type, out: []u8, seed: *const [Hash.digest_length]u8, len: usize) []u8 {
            var block_input: [Hash.digest_length + 4]u8 = undefined;
            @memcpy(block_input[0..Hash.digest_length], seed);
            var counter: u32 = 0;
            var idx: usize = 0;
            while (idx < len) : (idx += Hash.digest_length) {
                std.mem.writeInt(u32, block_input[Hash.digest_length..][0..4], counter, .big);
                Hash.hash(&block_input, out[idx..][0..Hash.digest_length], .{});
                counter += 1;
            }
            return out[0..len];
        }
    };

    /// RFC 3447 8.2 RSASSA-PKCS1-v1_5
    pub const PKCS1v1_5Signature = struct {
        /// Copies `msg` to the front of a `modulus_len`-byte array and zero-fills
        /// the remainder (the padding is appended on the right, not the left).
        /// `msg.len` must not exceed `modulus_len`.
        pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 {
            var result: [modulus_len]u8 = undefined;
            @memcpy(result[0..msg.len], msg);
            @memset(result[msg.len..], 0);
            return result;
        }

        pub const VerifyError = EncryptError || error{InvalidSignature};

        /// Verifies an RSASSA-PKCS1-v1_5 signature `sig` over `msg` using `Hash`.
        pub fn verify(
            comptime modulus_len: usize,
            sig: [modulus_len]u8,
            msg: []const u8,
            public_key: PublicKey,
            comptime Hash: type,
        ) VerifyError!void {
            try concatVerify(modulus_len, sig, &.{msg}, public_key, Hash);
        }

        /// Same as `verify`, but the signed message is the concatenation of all
        /// slices in `msg`. The decrypted signature must equal the expected
        /// encoding byte for byte.
        pub fn concatVerify(
            comptime modulus_len: usize,
            sig: [modulus_len]u8,
            msg: []const []const u8,
            public_key: PublicKey,
            comptime Hash: type,
        ) VerifyError!void {
            const em_dec = try encrypt(modulus_len, sig, public_key);
            const em = try EMSA_PKCS1_V1_5_ENCODE(msg, modulus_len, Hash);
            if (!std.mem.eql(u8, &em_dec, &em)) return error.InvalidSignature;
        }

        /// EMSA-PKCS1-v1_5-ENCODE (RFC 3447 9.2): builds
        /// EM = 0x00 || 0x01 || PS || 0x00 || T for an `emLen`-byte modulus.
        /// Only SHA-1 and the SHA-2 family are supported; other hashes are a
        /// compile error.
        fn EMSA_PKCS1_V1_5_ENCODE(msg: []const []const u8, comptime emLen: usize, comptime Hash: type) VerifyError![emLen]u8 {
            // DER prefix of the DigestInfo value T (everything except the digest
            // itself, which follows it):
            //
            //     DigestInfo ::= SEQUENCE {
            //         digestAlgorithm AlgorithmIdentifier,
            //         digest OCTET STRING
            //     }
            const hash_der: []const u8 = &switch (Hash) {
                crypto.hash.Sha1 => .{
                    0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e,
                    0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14,
                },
                crypto.hash.sha2.Sha224 => .{
                    0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
                    0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05,
                    0x00, 0x04, 0x1c,
                },
                crypto.hash.sha2.Sha256 => .{
                    0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
                    0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05,
                    0x00, 0x04, 0x20,
                },
                crypto.hash.sha2.Sha384 => .{
                    0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
                    0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05,
                    0x00, 0x04, 0x30,
                },
                crypto.hash.sha2.Sha512 => .{
                    0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
                    0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05,
                    0x00, 0x04, 0x40,
                },
                else => @compileError("unreachable"),
            };

            // 3. If emLen < tLen + 11, output "intended encoded message length too
            //    short" and stop. Both sides are comptime-known, so this is reported
            //    at build time instead of as an obscure comptime underflow below.
            if (comptime (emLen < hash_der.len + Hash.digest_length + 11)) {
                @compileError("RSA modulus is too short for a PKCS#1 v1.5 signature with this hash");
            }

            // EM is filled from the right; em_index is the start of the part
            // written so far.
            comptime var em_index = emLen;
            var em: [emLen]u8 = undefined;

            // 1. Apply the hash function to the message M to produce a hash value
            //    H = Hash(M), placed at the end of EM.
            var hasher: Hash = .init(.{});
            for (msg) |part| hasher.update(part);
            em_index -= Hash.digest_length;
            hasher.final(em[em_index..]);

            // 2. Prepend the DER DigestInfo prefix so that T = prefix || H.
            em_index -= hash_der.len;
            @memcpy(em[em_index..][0..hash_der.len], hash_der);

            // 4. Generate an octet string PS consisting of emLen - tLen - 3 octets
            //    with hexadecimal value 0xff. The length of PS will be at least 8
            //    octets.
            em_index -= 1;
            @memset(em[2..em_index], 0xff);

            // 5. Concatenate PS, the DER encoding T, and other padding to form the
            //    encoded message EM as
            //
            //       EM = 0x00 || 0x01 || PS || 0x00 || T.
            em[em_index] = 0x00;
            em[1] = 0x01;
            em[0] = 0x00;

            // 6. Output EM.
            return em;
        }
    };

    /// An RSA public key: modulus `n` and public exponent `e` (reduced modulo `n`).
    pub const PublicKey = struct {
        n: Modulus,
        e: Fe,

        pub const FromBytesError = error{CertificatePublicKeyInvalid};

        /// Builds a public key from big-endian exponent bytes `pub_bytes` and
        /// modulus bytes `modulus_bytes`. Rejects moduli below 512 bits and
        /// exponents that are even, equal to 1, or not below 2^32. Leading zero
        /// octets of the exponent, as produced by DER INTEGER encoding for
        /// values with the top bit set, are ignored.
        pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) FromBytesError!PublicKey {
            // Reject modulus below 512 bits.
            // 512-bit RSA was factored in 1999, so this limit barely means anything,
            // but establish some limit now to ratchet in what we can.
            const _n = Modulus.fromBytes(modulus_bytes, .big) catch return error.CertificatePublicKeyInvalid;
            if (_n.bits() < 512) return error.CertificatePublicKeyInvalid;

            // Exponent must be odd and greater than 2.
            // Also, it must be less than 2^32 to mitigate DoS attacks.
            // Windows CryptoAPI doesn't support values larger than 32 bits [1], so it is
            // unlikely that exponents larger than 32 bits are being used for anything
            // Windows commonly does.
            // [1] https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-rsapubkey
            //
            // DER encodes e.g. 0x80000001 as 00 80 00 00 01, so strip leading zero
            // octets before the cheap length check.
            const exp_offset = for (pub_bytes, 0..) |byte, i| {
                if (byte != 0) break i;
            } else pub_bytes.len;
            const exp_bytes = pub_bytes[exp_offset..];
            if (exp_bytes.len > 4) return error.CertificatePublicKeyInvalid;
            const _e = Fe.fromBytes(_n, exp_bytes, .big) catch return error.CertificatePublicKeyInvalid;
            if (!_e.isOdd()) return error.CertificatePublicKeyInvalid;
            const e_v = _e.toPrimitive(u32) catch return error.CertificatePublicKeyInvalid;
            if (e_v < 2) return error.CertificatePublicKeyInvalid;

            return .{
                .n = _n,
                .e = _e,
            };
        }

        pub const ParseDerError = der.Element.ParseError || error{CertificateFieldHasWrongDataType};

        /// Splits a DER `SEQUENCE { modulus INTEGER, exponent INTEGER }` into the
        /// raw modulus (leading zero octets removed) and exponent byte slices.
        /// The returned slices point into `pub_key`.
        pub fn parseDer(pub_key: []const u8) ParseDerError!struct { modulus: []const u8, exponent: []const u8 } {
            const pub_key_seq = try der.Element.parse(pub_key, 0);
            if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
            const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
            if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
            const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
            if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
            // Skip over meaningless zeroes in the modulus.
            const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
            const modulus_offset = for (modulus_raw, 0..) |byte, i| {
                if (byte != 0) break i;
            } else modulus_raw.len;
            return .{
                .modulus = modulus_raw[modulus_offset..],
                .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end],
            };
        }
    };

    const EncryptError = error{MessageTooLong};

    /// RSA public-key operation: computes msg^e mod n and returns it as a
    /// big-endian `modulus_len`-byte array. Returns `MessageTooLong` when `msg`
    /// cannot be converted to an element modulo n (presumably because it is
    /// not below n). Assumes `modulus_len` is large enough to hold n.
    fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey) EncryptError![modulus_len]u8 {
        const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong;
        const e = public_key.n.powPublic(m, public_key.e) catch unreachable;
        var res: [modulus_len]u8 = undefined;
        e.toBytes(&res, .big) catch unreachable;
        return res;
    }
};