Files
rose-ash/hosts/ocaml/lib/sx_rsa.ml
giles f8fc04840a
Some checks failed
Test, Build, and Deploy / test-build-deploy (push) Failing after 3m9s
fed-prims: Phase F — RSA-SHA256 PKCS#1 v1.5 verify, pure OCaml, RSA-2048 vector
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 17:32:35 +00:00

221 lines
6.6 KiB
OCaml

(** RSASSA-PKCS1-v1_5 verification with SHA-256 — pure OCaml,
WASM-safe. Self-contained minimal bignum (modexp only), a tiny
DER reader for SubjectPublicKeyInfo, and the fixed SHA-256
DigestInfo prefix. Verification only handles public data, so constant
time is not required. Reference: RFC 8017 §8.2.2, §9.2. No library
deps beyond Sx_sha2 (for SHA-256). *)
(* ---- Minimal unsigned bignum: int array, little-endian limbs ---- *)

(** Bits per limb. [mul] adds a limb-by-limb product, a limb and a carry,
    which needs [2 * bits] bits of headroom in a native int. 26 is used
    where ints are 63 bits (native 64-bit). Where ints are narrower
    (js_of_ocaml: 32, wasm_of_ocaml: 31) 15 is used, so the arithmetic
    stays exact instead of silently wrapping. Must stay >= 8 so that
    [base] is a multiple of 256. *)
let bits = if Sys.int_size >= 53 then 26 else 15

(** Limb radix, [2^bits]. *)
let base = 1 lsl bits

(** Low-[bits] mask, used to split a sum into limb and carry. *)
let mask = base - 1

(** Unsigned bignum: limbs in base [base], least significant first. *)
type bn = int array
(** Drops most-significant zero limbs, always keeping at least one limb.
    Returns [a] itself (physically) when there is nothing to trim. *)
let norm a =
  let len = Array.length a in
  let rec keep n = if n > 1 && a.(n - 1) = 0 then keep (n - 1) else n in
  let n = keep len in
  if n = len then a else Array.sub a 0 n
(** The bignum zero: a single zero limb. This is one shared array, so it
    must never be mutated. *)
let bzero : bn = Array.make 1 0
(** True only for the normalized zero [[| 0 |]]. Unnormalized zeros such
    as [[| 0; 0 |]] are reported as non-zero, so pass normalized input. *)
let is_zero a = match a with [| 0 |] -> true | _ -> false
(** Three-way comparison of bignums, ignoring leading zero limbs. Returns
    a negative value, zero, or a positive value as [a] is less than,
    equal to, or greater than [b]. *)
let cmp a b =
  let a = norm a and b = norm b in
  match compare (Array.length a) (Array.length b) with
  | 0 ->
    (* Same limb count: the most significant differing limb decides. *)
    let rec scan i =
      if i < 0 then 0
      else if a.(i) = b.(i) then scan (i - 1)
      else compare a.(i) b.(i)
    in
    scan (Array.length a - 1)
  | c -> c
(** [add a b] is [a + b], normalized. Missing limbs of the shorter
    operand count as zero. *)
let add a b =
  let limb x k = if k < Array.length x then x.(k) else 0 in
  let carry = ref 0 in
  (* One extra limb absorbs the final carry. Array.init fills in index
     order, so the carry threads through correctly. *)
  let sum =
    Array.init (max (Array.length a) (Array.length b) + 1) (fun k ->
      let s = limb a k + limb b k + !carry in
      carry := s lsr bits;
      s land mask)
  in
  norm sum
(** [sub a b] is [a - b], normalized. Precondition: [a >= b]. If it is
    violated, the final borrow is silently dropped and the result is
    meaningless. *)
let sub a b =
  let limb x k = if k < Array.length x then x.(k) else 0 in
  let borrow = ref 0 in
  let diff =
    Array.init (Array.length a) (fun k ->
      let d = a.(k) - !borrow - limb b k in
      if d < 0 then begin borrow := 1; d + base end
      else begin borrow := 0; d end)
  in
  norm diff
(** Schoolbook product [a * b], normalized. Each intermediate
    [prod + ai * bj + carry] stays below [2^(2 * bits)], which must fit
    in a native int. *)
let mul a b =
  let la = Array.length a and lb = Array.length b in
  let prod = Array.make (la + lb) 0 in
  for i = 0 to la - 1 do
    let ai = a.(i) and carry = ref 0 in
    Array.iteri
      (fun j bj ->
        let t = prod.(i + j) + (ai * bj) + !carry in
        prod.(i + j) <- t land mask;
        carry := t lsr bits)
      b;
    (* prod.(i + lb) is still untouched here, so this cannot overflow a limb. *)
    prod.(i + lb) <- prod.(i + lb) + !carry
  done;
  norm prod
(** Bit length of [a]: the index of its highest set bit plus one, or 0
    when [a] is zero. *)
let numbits a =
  let a = norm a in
  let top = Array.length a - 1 in
  if top = 0 && a.(0) = 0 then 0
  else begin
    let rec width v acc = if v > 0 then width (v lsr 1) (acc + 1) else acc in
    (top * bits) + width a.(top) 0
  end
(** [bit a i] is bit [i] of [a] (0 = least significant), as 0 or 1. Bits
    beyond the last limb read as 0. *)
let bit a i =
  let limb = i / bits in
  if limb < Array.length a then (a.(limb) lsr (i mod bits)) land 1 else 0
(** [bn_mod a m] is [a mod m], computed by bit-serial long division;
    assumes [m > 0]. It shifts [a] in one bit at a time, most significant
    first. Because the partial remainder stays below [m], one conditional
    subtraction per bit is enough. *)
let bn_mod a m =
  if cmp a m < 0 then norm a
  else begin
    let rec go i r =
      if i < 0 then r
      else begin
        let r = add r r in
        let r = if bit a i = 1 then add r [| 1 |] else r in
        let r = if cmp r m >= 0 then sub r m else r in
        go (i - 1) r
      end
    in
    go (numbits a - 1) bzero
  end
(** [powmod b0 e m] is [b0^e mod m], by right-to-left binary
    exponentiation; assumes [m > 0]. The result is always fully reduced,
    so [e = 0] with [m = 1] yields 0. Runs [numbits e] iterations, each
    with at most two [mul] + [bn_mod] steps. *)
let powmod b0 e m =
  let nb = numbits e in
  (* Reduce the initial 1 too; an unreduced [| 1 |] was returned as-is
     for e = 0 even when m = 1. *)
  let result = ref (bn_mod [| 1 |] m) and b = ref (bn_mod b0 m) in
  for i = 0 to nb - 1 do
    if bit e i = 1 then result := bn_mod (mul !result !b) m;
    (* The square after the top bit would never be used, so skip it. *)
    if i < nb - 1 then b := bn_mod (mul !b !b) m
  done;
  !result
(** Big-endian unsigned byte string to bignum; the empty string is zero.
    Uses Horner's rule: for each byte, acc <- acc * 256 + byte. *)
let of_bytes_be (s : string) : bn =
  let acc = ref bzero in
  String.iter (fun ch -> acc := add (mul !acc [| 256 |]) [| Char.code ch |]) s;
  !acc
(** [div_small a d] is the normalized quotient of [a] by a small positive
    int [d]; the remainder is discarded. Assumes [d * base] fits in a
    native int. *)
let div_small a d =
  let q = Array.make (Array.length a) 0 in
  (* Walk limbs from most significant down, carrying the remainder. *)
  let rec step i rem =
    if i >= 0 then begin
      let cur = (rem lsl bits) lor a.(i) in
      q.(i) <- cur / d;
      step (i - 1) (cur mod d)
    end
  in
  step (Array.length a - 1) 0;
  norm q
(** [to_bytes_be a n] renders [a] as exactly [n] big-endian bytes,
    left-padded with zeros. Bytes above position [n] are dropped, so the
    caller must ensure [a < 256^n]. *)
let to_bytes_be (a : bn) (n : int) : string =
  let out = Bytes.make n '\000' in
  let rec fill i cur =
    if i >= 0 then begin
      (* [base] is a multiple of 256, so [cur mod 256] is just the low
         byte of limb 0 ([cur] is always normalized, hence non-empty). *)
      Bytes.set out i (Char.chr (cur.(0) land 0xff));
      fill (i - 1) (div_small cur 256)
    end
  in
  fill (n - 1) (norm a);
  Bytes.unsafe_to_string out
(* ---- Minimal DER reader (for SubjectPublicKeyInfo) ---- *)

(** Raised on malformed or unsupported DER input. *)
exception Der of string

(** [der_tlv s pos] decodes the DER tag/length header at offset [pos] of
    [s]. It returns [(tag, content_start, content_len, next)], where
    [next] is the offset just past the content. Only single-byte tags are
    supported.
    @raise Der if the header is truncated, the length is indefinite
    (0x80) or too long to fit in an [int], or the content runs past the
    end of [s]. *)
let der_tlv s pos =
  let slen = String.length s in
  if pos < 0 || pos + 2 > slen then raise (Der "short");
  let tag = Char.code s.[pos] in
  let l0 = Char.code s.[pos + 1] in
  let len, hdr =
    if l0 < 0x80 then (l0, 2)
    else begin
      let nb = l0 land 0x7f in
      (* 0x80 is BER indefinite length, which DER forbids. Cap the number
         of length octets so the accumulator can never wrap, whatever
         the platform's int width. *)
      if nb = 0 || nb > (Sys.int_size - 2) / 8 then raise (Der "bad len");
      if pos + 2 + nb > slen then raise (Der "short len");
      let v = ref 0 in
      for i = 0 to nb - 1 do
        v := (!v lsl 8) lor Char.code s.[pos + 2 + i]
      done;
      (!v, 2 + nb)
    end
  in
  (* The content must lie entirely inside [s]. *)
  if len < 0 || len > slen - pos - hdr then raise (Der "truncated");
  (tag, pos + hdr, len, pos + hdr + len)
(* SPKI DER -> (n, e) as bignums. *)
(** [parse_spki der] extracts the RSA modulus [n] and public exponent [e]
    from a DER SubjectPublicKeyInfo with this shape:
    {v SEQUENCE { AlgorithmIdentifier, BIT STRING { RSAPublicKey } } v}
    Here RSAPublicKey = SEQUENCE { INTEGER n, INTEGER e }. The algorithm
    OID itself is not inspected. INTEGER contents are read as unsigned
    big-endian.
    @raise Der on structural errors (raised here or by [der_tlv]).
    Truncated input may also surface as [Invalid_argument] from
    [String.sub]. *)
let parse_spki (der : string) : bn * bn =
  let tag, c, _, _ = der_tlv der 0 in
  if tag <> 0x30 then raise (Der "spki: outer not SEQUENCE");
  (* AlgorithmIdentifier: must be a SEQUENCE; its contents are skipped. *)
  let at, _, _, after_alg = der_tlv der c in
  if at <> 0x30 then raise (Der "spki: AlgorithmIdentifier not SEQUENCE");
  let bt, bc, bl, _ = der_tlv der after_alg in
  if bt <> 0x03 then raise (Der "spki: expected BIT STRING");
  (* The first content byte is the unused-bits count. An encoded key is
     whole octets, so it must be 0. *)
  if bl < 1 || bc >= String.length der || der.[bc] <> '\x00' then
    raise (Der "spki: BIT STRING unused bits not 0");
  let st, sc, _, _ = der_tlv der (bc + 1) in
  if st <> 0x30 then raise (Der "spki: RSAPublicKey not SEQUENCE");
  let nt, nc, nl, after_n = der_tlv der sc in
  if nt <> 0x02 then raise (Der "spki: modulus not INTEGER");
  let et, ec, el, _ = der_tlv der after_n in
  if et <> 0x02 then raise (Der "spki: exponent not INTEGER");
  let n = of_bytes_be (String.sub der nc nl) in
  let e = of_bytes_be (String.sub der ec el) in
  (n, e)
(* SHA-256 DigestInfo DER prefix (RFC 8017 §9.2 note 1).
   Byte layout:
     30 31                      SEQUENCE, 49 bytes
       30 0d                    SEQUENCE (AlgorithmIdentifier), 13 bytes
         06 09 60 86 48 01 65 03 04 02 01   OID 2.16.840.1.101.3.4.2.1 (id-sha256)
         05 00                  NULL parameters
       04 20                    OCTET STRING, 32 bytes
   The 32-byte SHA-256 digest follows immediately after this prefix. *)
let sha256_digestinfo_prefix =
"\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20"
(** Decodes a hex string, two characters per byte; an odd trailing
    character is ignored. Each pair is parsed with [int_of_string], so an
    invalid pair raises [Failure]. *)
let unhex h =
  String.init (String.length h / 2) (fun i ->
    Char.chr (int_of_string ("0x" ^ String.sub h (2 * i) 2)))
(** [verify ~spki ~msg ~sig_] checks an RSASSA-PKCS1-v1_5 signature with
    SHA-256 (RFC 8017 §8.2.2). The arguments are:
    - [spki]: the DER SubjectPublicKeyInfo;
    - [msg]: the signed bytes;
    - [sig_]: the raw signature, which must be exactly k bytes, where k
      is the modulus length in bytes.
    The function is total: any exception raised while checking (for
    example a malformed key) yields [false].
    NOTE(review): the key size and exponent are not bounded here, so an
    oversized key makes verification slow. *)
let verify ~spki ~msg ~sig_ : bool =
  let check () =
    let n, e = parse_spki spki in
    let k = (numbits n + 7) / 8 in
    if String.length sig_ <> k then false
    else begin
      let s = of_bytes_be sig_ in
      (* The signature representative must be < n. *)
      if cmp s n >= 0 then false
      else begin
        let em = to_bytes_be (powmod s e n) k in
        let t = sha256_digestinfo_prefix ^ unhex (Sx_sha2.sha256_hex msg) in
        let tlen = String.length t in
        (* Rebuild the only acceptable encoding,
             0x00 0x01 || FF..FF || 0x00 || DigestInfo || H,
           and compare the whole string. The length guard guarantees at
           least 8 bytes of 0xff padding. *)
        k >= tlen + 11
        && String.equal em
             (String.concat ""
                [ "\x00\x01"; String.make (k - tlen - 3) '\xff'; "\x00"; t ])
      end
    end
  in
  try check () with _ -> false