Support for PEM-encoded RSA and EC keys.

This commit is contained in:
Jonathan Bernard 2021-12-24 00:43:19 -06:00
parent d566d31ac3
commit 73189ebf15
3 changed files with 232 additions and 34 deletions

View File

@ -1,5 +1,5 @@
import std/options, std/tables import std/tables
import bearssl import bearssl, bearssl_pkey_decoder
import ../../jwt/jwa import ../../jwt/jwa
import ../../jwt/jwk import ../../jwt/jwk
@ -7,6 +7,7 @@ import ../../jwt/jwk
import ../encoding import ../encoding
import ./hash import ./hash
import ./pem
type type
EcPublicKeyObj = object EcPublicKeyObj = object
@ -25,16 +26,23 @@ func toBearSslCurveConst(curve: EcCurve): int32 =
of P384: EC_secp384r1 of P384: EC_secp384r1
of P521: EC_secp521r1 of P521: EC_secp521r1
func fromBearSslCurveConst(curve: int32): EcCurve =
result = case curve:
of EC_secp256r1: P256
of EC_secp384r1: P384
of EC_secp521r1: P521
else: P256
func initEcPublicKeyObj(curve: EcCurve, q: string): EcPublicKeyObj = func initEcPublicKeyObj(curve: EcCurve, q: string): EcPublicKeyObj =
result = EcPublicKeyObj(curve: curve, q: q) result = EcPublicKeyObj(curve: curve, q: q)
result.bearKey.curve = curve.toBearSslCurveConst result.bearKey.curve = curve.toBearSslCurveConst
result.bearKey.q = cast[ptr cuchar](result.q.cstring) result.bearKey.q = cast[ptr char](result.q.cstring)
result.bearKey.qlen = q.len result.bearKey.qlen = q.len
func initEcPrivateKeyObj(curve: EcCurve, x: string): EcPrivateKeyObj = func initEcPrivateKeyObj(curve: EcCurve, x: string): EcPrivateKeyObj =
result = EcPrivateKeyObj(curve: curve, x: x) result = EcPrivateKeyObj(curve: curve, x: x)
result.bearKey.curve = curve.toBearSslCurveConst result.bearKey.curve = curve.toBearSslCurveConst
result.bearKey.x = cast[ptr cuchar](result.x.cstring) result.bearKey.x = cast[ptr char](result.x.cstring)
result.bearKey.xlen = x.len result.bearKey.xlen = x.len
proc toEcPublicKey(jwk: JWK): EcPublicKeyObj = proc toEcPublicKey(jwk: JWK): EcPublicKeyObj =
@ -52,6 +60,15 @@ proc toEcPublicKey(jwk: JWK): EcPublicKeyObj =
curve = keyObj.crv, curve = keyObj.crv,
q = b64UrlDecode(keyObj.x)) q = b64UrlDecode(keyObj.x))
proc toEcPublicKey(pem: string): EcPublicKeyObj =
var pkCtx: PkeyDecoderContext
pkeyDecoderInit(addr pkCtx)
decodePem(pem, "EC PUBLIC KEY", addr pkCtx,
cast[BearPemDecoderCallback](pkeyDecoderPush))
let k = pkCtx.key.ec
return initEcPublicKeyObj(k.curve.fromBearSslCurveConst, $cstring(k.q))
proc toEcPrivateKey(jwk: JWK): EcPrivateKeyObj = proc toEcPrivateKey(jwk: JWK): EcPrivateKeyObj =
## Convert an ECDSA private key in JWK format to the wrapper for BearSSL's ## Convert an ECDSA private key in JWK format to the wrapper for BearSSL's
## ECPrivateKey struct. ## ECPrivateKey struct.
@ -63,6 +80,15 @@ proc toEcPrivateKey(jwk: JWK): EcPrivateKeyObj =
curve = jwk.ecPrv.crv, curve = jwk.ecPrv.crv,
x = b64UrlDecode(jwk.ecPrv.d)) x = b64UrlDecode(jwk.ecPrv.d))
proc toEcPrivateKey(pem: string): EcPrivateKeyObj =
var skCtx: SkeyDecoderContext
skeyDecoderInit(addr skCtx)
decodePem(pem, "EC PRIVATE KEY", addr skCtx,
cast[BearPemDecoderCallback](skeyDecoderPush))
let k = skCtx.key.ec
return initEcPrivateKeyObj(k.curve.fromBearSslCurveConst, $cstring(k.x))
proc getEcHashCfg(alg: JwtAlgorithm): HashCfg = proc getEcHashCfg(alg: JwtAlgorithm): HashCfg =
let hashAlg = case alg: let hashAlg = case alg:
of ES256: SHA256 of ES256: SHA256
@ -93,9 +119,9 @@ proc bearEcSign(
let sigLen = ecSignImpl( let sigLen = ecSignImpl(
addr ecAllM15, addr ecAllM15,
hashCfg.vtable, hashCfg.vtable,
cast[ptr cuchar](unsafeAddr hashed[0]), cast[ptr char](unsafeAddr hashed[0]),
unsafeAddr key.bearKey, unsafeAddr key.bearKey,
cast[ptr cuchar](addr result[0])) cast[ptr char](addr result[0]))
if sigLen == 0: raise newException(Exception, "EC signature failed") if sigLen == 0: raise newException(Exception, "EC signature failed")
result.setLen(sigLen) result.setLen(sigLen)
@ -112,10 +138,10 @@ proc bearEcVerify(
let ecVerifyImpl = ecdsaVrfyRawGetDefault() let ecVerifyImpl = ecdsaVrfyRawGetDefault()
let resultCode = ecVerifyImpl( let resultCode = ecVerifyImpl(
addr ecAllM15, addr ecAllM15,
cast[ptr cuchar](unsafeAddr hashed[0]), cast[ptr char](unsafeAddr hashed[0]),
hashed.len, hashed.len,
unsafeAddr key.bearKey, unsafeAddr key.bearKey,
cast[ptr cuchar](unsafeAddr signature[0]), cast[ptr char](unsafeAddr signature[0]),
signature.len) signature.len)
result = resultCode == 1 result = resultCode == 1

View File

@ -0,0 +1,106 @@
import bearssl
# Taken from nim-bearssl/decls.nim
{.pragma: bearSslFunc, cdecl, gcsafe, noSideEffect, raises: [].}
type BearPemDecoderCallback* = proc(keyCtx: pointer, data: pointer, dataLen: int) {.bearSslFunc.}
proc decodePem*(
pem: string,
expectedObjectName: string,
keyCtx: pointer,
callback: BearPemDecoderCallback
) =
var pemCtx: PemDecoderContext
pemDecoderInit(addr pemCtx)
var bytesRead = 0
var readingObj = false
while bytesRead < len(pem):
bytesRead += pemDecoderPush(addr pemCtx, unsafeAddr pem[bytesRead], len(pem))
case pemDecoderEvent(addr pemCtx):
of PEM_BEGIN_OBJ:
if readingObj:
raise newException(ValueError,
"Invalid PEM: saw a second BEGIN before seeing END.")
if pemDecoderName(addr pemCtx) != expectedObjectName:
raise newException(ValueError,
"Invalid PEM: expected BEGIN " & expectedObjectName &
" but got BEGIN " & $pemDecoderName(addr pemCtx))
readingObj = true
pemDecoderSetdest(addr pemCtx, callback, keyCtx)
of PEM_END_OBJ:
if readingObj: readingObj = false
of PEM_ERROR: raise newException(ValueError, "Invalid PEM.")
else: continue
#[
proc decodeKeyFromPem*(pem: string): JWK =
var decoderCtx: PemDecoderContext
pemDecoderInit(addr decoderCtx)
var bytesRead = 0
var skCtx: SkeyDecoderContext
var pkCtx: PkeyDecoderContext
var keyCtx: pointer = nil
var keyType: JwkKeyType
while bytesRead < len(pem):
bytesRead += pemDecoderPush(
addr decoderCtx,
unsafeAddr pem[bytesRead],
len(pem))
case pemDecoderEvent(addr decoderCtx):
of PEM_BEGIN_OBJ:
if keyCtx != nil:
raise newException(ValueError, "Invalid PEM: saw a second BEGIN")
let objName = pemDecoderName(addr decoderCtx)
if objName == "RSA PRIVATE KEY":
skeyDecoderInit(addr skCtx)
keyCtx = addr skCtx
keyType = JwkKeyType.RsaPrivate
if objName == "EC PRIVATE KEY":
skeyDecoderInit(addr skCtx)
keyCtx = addr skCtx
keyType = JwkKeyType.EcPrivate
elif objName == "RSA PUBLIC KEY":
pkeyDecoderInit(addr pkCtx)
keyCtx = addr pkCtx
keyType = JwkKeyType.RsaPublic
else:
raise newException(ValueError,
"Unrecognized PEM object: " & $pemDecoderName(addr decoderCtx))
of PEM_END_OBJ:
if keyCtx == nil:
raise newException(ValueError, "Saw END before BEGIN.")
elif keyType == JwkKeyType.RsaPrivate:
return nil
of PEM_ERROR:
raise newException(ValueError, "Invalid PEM.")
else: discard nil
proc decodeRsaPem*(pem: string): JWK =
var skCtx: SkeyDecoderContext
skeyDecoderInit(addr skCtx)
decodePem(
pem,
cast[proc(ctx: pointer, data: pointer, dataLen: int) {.bearSslFunc.}](skeyDecoderPush)
addr skCtx)
if skeyDecoderLastError(addr skCtx) != 0:
raise newException(ValueError,
"Provided PEM could not be decoded as a valid RSA private key.")
]#

View File

@ -1,5 +1,5 @@
import std/options, std/tables import std/options, std/tables
import bearssl import bearssl, bearssl_pkey_decoder
import ../../jwt/jwa import ../../jwt/jwa
import ../../jwt/jwk import ../../jwt/jwk
@ -7,6 +7,7 @@ import ../../jwt/jwk
import ../encoding import ../encoding
import ./hash import ./hash
import ./pem
type type
RsaPublicKeyObj = object RsaPublicKeyObj = object
@ -15,50 +16,87 @@ type
RsaPrivateKeyObj = object RsaPrivateKeyObj = object
p*, q*, dp*, dq*, iq*: string p*, q*, dp*, dq*, iq*: string
bearKey*: RsaPrivateKey bearKey: RsaPrivateKey
func initRsaPublicKeyObj(n, e: string): RsaPublicKeyObj = func initRsaPublicKeyObj(n, e: string): RsaPublicKeyObj =
result = RsaPublicKeyObj(n: n, e: e) result = RsaPublicKeyObj(n: n, e: e)
result.bearKey.n = cast[ptr cuchar](result.n.cstring) result.bearKey.n = cast[ptr char](result.n.cstring)
result.bearKey.nlen = result.n.len result.bearKey.nlen = result.n.len
result.bearKey.e = cast[ptr cuchar](result.e.cstring) result.bearKey.e = cast[ptr char](result.e.cstring)
result.bearKey.elen = result.e.len result.bearKey.elen = result.e.len
func initRsaPrivateKeyObj(nBitLen: int, p, q, dp, dq, iq: string): RsaPrivateKeyObj = func initRsaPrivateKeyObj(nBitLen: int, p, q, dp, dq, iq: string): RsaPrivateKeyObj =
result = RsaPrivateKeyObj(p: p, q: q, dp: dp, dq: dq, iq: iq) result = RsaPrivateKeyObj(p: p, q: q, dp: dp, dq: dq, iq: iq)
result.bearKey.nBitLen = cast[uint32](nBitLen) result.bearKey.nBitLen = cast[uint32](nBitLen)
result.bearKey.p = cast[ptr cuchar](result.p.cstring) result.bearKey.p = cast[ptr char](result.p.cstring)
result.bearKey.plen = result.p.len result.bearKey.plen = result.p.len
result.bearKey.q = cast[ptr cuchar](result.q.cstring) result.bearKey.q = cast[ptr char](result.q.cstring)
result.bearKey.qlen = result.q.len result.bearKey.qlen = result.q.len
result.bearKey.dp = cast[ptr cuchar](result.dp.cstring) result.bearKey.dp = cast[ptr char](result.dp.cstring)
result.bearKey.dplen = result.dp.len result.bearKey.dplen = result.dp.len
result.bearKey.dq = cast[ptr cuchar](result.dq.cstring) result.bearKey.dq = cast[ptr char](result.dq.cstring)
result.bearKey.dqlen = result.dq.len result.bearKey.dqlen = result.dq.len
result.bearKey.iq = cast[ptr cuchar](result.iq.cstring) result.bearKey.iq = cast[ptr char](result.iq.cstring)
result.bearKey.iqlen = result.iq.len result.bearKey.iqlen = result.iq.len
proc toRsaPublicKey(jwk: JWK): RsaPublicKeyObj = proc toRsaPublicKey(jwk: JWK): RsaPublicKeyObj =
## Convert an RSA public key in JWK format to the wrapper for BearSSL's ## Convert an RSA public key in JWK format to the wrapper for BearSSL's
## RsaPublicKey struct. ## RsaPublicKey struct.
if jwk.keyKind != JwkKeyType.RsaPublic and
jwk.keyKind != JwkKeyType.RsaPrivate:
raise newException(ValueError,
"Can not extract RSA public key from the given JWT: keyKind is " &
$jwk.keyKind)
return initRsaPublicKeyObj( return initRsaPublicKeyObj(
n = b64UrlDecode(jwk.rsaPub.n), n = b64UrlDecode(jwk.rsaPub.n),
e = b64UrlDecode(jwk.rsaPub.e)) e = b64UrlDecode(jwk.rsaPub.e))
proc toRsaPublicKey(pem: string): RsaPublicKeyObj =
var pkCtx: PkeyDecoderContext
pkeyDecoderInit(addr pkCtx)
decodePem(pem, "RSA PUBLIC KEY", addr pkCtx,
cast[BearPemDecoderCallback](pkeyDecoderPush))
let k = pkCtx.key.rsa
return initRsaPublicKeyObj($cstring(k.n), $cstring(k.e))
proc toRsaPrivateKey(jwk: JWK): RsaPrivateKeyObj = proc toRsaPrivateKey(jwk: JWK): RsaPrivateKeyObj =
## Convert an RSA private key in JWK format to the wrapper for BearSSL's ## Convert an RSA private key in JWK format to the wrapper for BearSSL's
## RsaPrivateKey struct. ## RsaPrivateKey struct.
##
## Note: there are two ways to represent an rsa private key[1][rsa-priv-rep]:
## 1. The pair modulus and private exponent pair (*n*, *d*)
## 2. The quintuple with the prime factors, reduced CRT exponents, and first
## CRT coefficient (*p*, *q*, *dP*, *dQ*, *qInv*)
##
## The JWA standard requires that a JWK representing an RSA private key
## provide rep. 1 (*n*, *d*) and optionally allows rep. 2 to allows be
## included. BearSSL strictly uses rep. 2 in it's key structures.
## **Currently, this library only supports keys that provide the parameters
## for rep. 2.** Ideally this library would also support keys only containing
## rep. 1, including the logic to derive rep. 2 from use in BearSSL. This is
## a TODO.
##
## [NIST Special Publication 800-56B][NIST.SP.800-56Br2] provides a choice of
## two algorithms to derive the prime factors *p* and *q* from the private
## exponent *d*. With the prime factors known, the calculation *dP*, *dQ*,
## and *qInv* are straighforward based on [their definitions in RFC8
## 8017][rsa-apdx-a]
##
## [rsa-priv-rep]: https://datatracker.ietf.org/doc/html/rfc8017#section-3.2
## [rsa-apdx-a]: https://datatracker.ietf.org/doc/html/rfc8017#appendix-A.1.2
## [NIST.SP.800-56Br2]: https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Br2.pdf
# TODO: JWS spec only requires a private key to have n, e, and d, as the if jwk.keyKind != JwkKeyType.RsaPrivate:
# remainder can be computed form these (p, q, dp, dq, and qi). BearSSL raise newException(ValueError,
# requires p, q, dp, dq, and qi (it calls iq). Because of this, we currently "Can not extract RSA private key from the given JWT: keyKind is " &
# require all values to be present in JWKs for RSA privat keys. We should add $jwk.keyKind)
# the logic to compute the missing values to fully support the JWS spec.
#
# We also do not currently support keys with more than two prime factors.
let sk = jwk.rsaPrv let sk = jwk.rsaPrv
if sk.p.isNone or sk.q.isNone or sk.dp.isNone or sk.dq.isNone or sk.qi.isNone: if sk.p.isNone or sk.q.isNone or sk.dp.isNone or sk.dq.isNone or sk.qi.isNone:
raise newException(ValueError, raise newException(ValueError,
"RSA private key must have values for: n, e, d, p, q, dp, dq, and qi") "RSA private key must have values for: n, e, d, p, q, dp, dq, and qi")
@ -72,6 +110,22 @@ proc toRsaPrivateKey(jwk: JWK): RsaPrivateKeyObj =
dq = b64UrlDecode(sk.dq.get), dq = b64UrlDecode(sk.dq.get),
iq = b64UrlDecode(sk.qi.get)) iq = b64UrlDecode(sk.qi.get))
proc toRsaPrivateKey(pem: string): RsaPrivateKeyObj =
var skCtx: SkeyDecoderContext
skeyDecoderInit(addr skCtx)
decodePem(pem, "RSA PRIVATE KEY", addr skCtx,
cast[BearPemDecoderCallback](skeyDecoderPush))
let k = skCtx.key.rsa
return initRsaPrivateKeyObj(
cast[int](k.nBitLen),
$cstring(k.p),
$cstring(k.q),
$cstring(k.dp),
$cstring(k.dq),
$cstring(k.iq))
proc getRsaHashCfg(alg: JwtAlgorithm): HashCfg = proc getRsaHashCfg(alg: JwtAlgorithm): HashCfg =
let hashAlg = case alg: let hashAlg = case alg:
of RS256: SHA256 of RS256: SHA256
@ -95,11 +149,11 @@ proc bearRsaSign(
result = newString((key.bearKey.nBitLen + 7) div 8) result = newString((key.bearKey.nBitLen + 7) div 8)
let errCode = rsaSignImpl( let errCode = rsaSignImpl(
cast[ptr cuchar](hashCfg.oid), cast[ptr char](hashCfg.oid),
cast[ptr cuchar](unsafeAddr hashed[0]), cast[ptr char](unsafeAddr hashed[0]),
hashed.len, hashed.len,
unsafeAddr key.bearKey, unsafeAddr key.bearKey,
cast[ptr cuchar](addr result[0])) cast[ptr char](addr result[0]))
if errCode != 1: raise newException(Exception, "RSA signature failed") if errCode != 1: raise newException(Exception, "RSA signature failed")
@ -116,12 +170,12 @@ proc bearRsaVerify(
var recoveredHash = newString(hashCfg.size) var recoveredHash = newString(hashCfg.size)
let errCode = rsaVerifyImpl( let errCode = rsaVerifyImpl(
cast[ptr cuchar](unsafeAddr signature[0]), cast[ptr char](unsafeAddr signature[0]),
signature.len, signature.len,
cast[ptr cuchar](hashCfg.oid), cast[ptr char](hashCfg.oid),
hashed.len, hashed.len,
unsafeAddr key.bearKey, unsafeAddr key.bearKey,
cast[ptr cuchar](addr recoveredHash[0])) cast[ptr char](addr recoveredHash[0]))
if errCode != 1: return false if errCode != 1: return false
return hashed == recoveredHash return hashed == recoveredHash
@ -138,6 +192,12 @@ proc rsaSign*(message: string, alg: JwtAlgorithm, key: JWK): string =
return bearRsaSign(message, alg, toRsaPrivateKey(key)) return bearRsaSign(message, alg, toRsaPrivateKey(key))
proc rsaSign*(message: string, alg: JwtAlgorithm, pemKey: string): string =
## Sign a message using hte RSA PKCS#1 v1.5 algorithm.
##
## *pemKey* is expected to be an RSA private key in PEM format.
return bearRsaSign(message, alg, toRsaPrivateKey(pemKey))
proc rsaVerify*(message, signature: string; alg: JwtAlgorithm, key: JWK): bool = proc rsaVerify*(message, signature: string; alg: JwtAlgorithm, key: JWK): bool =
## Verify the signature for a message using PKCS#1 v1.5 algorithm. ## Verify the signature for a message using PKCS#1 v1.5 algorithm.
## ##
@ -149,3 +209,9 @@ proc rsaVerify*(message, signature: string; alg: JwtAlgorithm, key: JWK): bool =
"\"typ\"=\"" & $(key.keyKind) & "\" key.") "\"typ\"=\"" & $(key.keyKind) & "\" key.")
return bearRsaVerify(message, signature, alg, toRsaPublicKey(key)) return bearRsaVerify(message, signature, alg, toRsaPublicKey(key))
proc rsaVerify*(message, signature: string; alg: JwtAlgorithm, pemKey: string): bool =
## Verify the signature for a message using PKCS#1 v1.5 algorithm.
##
## *key* is expected to be an RSA public key in PEM format.
return bearRsaVerify(message, signature, alg, toRsaPublicKey(pemKey))