diff --git a/src/main/private/crypto/ecdsa.nim b/src/main/private/crypto/ecdsa.nim index a8a226d..dacf720 100644 --- a/src/main/private/crypto/ecdsa.nim +++ b/src/main/private/crypto/ecdsa.nim @@ -1,5 +1,5 @@ -import std/options, std/tables -import bearssl +import std/tables +import bearssl, bearssl_pkey_decoder import ../../jwt/jwa import ../../jwt/jwk @@ -7,6 +7,7 @@ import ../../jwt/jwk import ../encoding import ./hash +import ./pem type EcPublicKeyObj = object @@ -21,20 +22,27 @@ type func toBearSslCurveConst(curve: EcCurve): int32 = result = case curve: - of P256: EC_secp256r1 - of P384: EC_secp384r1 - of P521: EC_secp521r1 + of P256: EC_secp256r1 + of P384: EC_secp384r1 + of P521: EC_secp521r1 + +func fromBearSslCurveConst(curve: int32): EcCurve = + result = case curve: + of EC_secp256r1: P256 + of EC_secp384r1: P384 + of EC_secp521r1: P521 + else: P256 func initEcPublicKeyObj(curve: EcCurve, q: string): EcPublicKeyObj = result = EcPublicKeyObj(curve: curve, q: q) result.bearKey.curve = curve.toBearSslCurveConst - result.bearKey.q = cast[ptr cuchar](result.q.cstring) + result.bearKey.q = cast[ptr char](result.q.cstring) result.bearKey.qlen = q.len func initEcPrivateKeyObj(curve: EcCurve, x: string): EcPrivateKeyObj = result = EcPrivateKeyObj(curve: curve, x: x) result.bearKey.curve = curve.toBearSslCurveConst - result.bearKey.x = cast[ptr cuchar](result.x.cstring) + result.bearKey.x = cast[ptr char](result.x.cstring) result.bearKey.xlen = x.len proc toEcPublicKey(jwk: JWK): EcPublicKeyObj = @@ -52,6 +60,15 @@ proc toEcPublicKey(jwk: JWK): EcPublicKeyObj = curve = keyObj.crv, q = b64UrlDecode(keyObj.x)) +proc toEcPublicKey(pem: string): EcPublicKeyObj = + var pkCtx: PkeyDecoderContext + pkeyDecoderInit(addr pkCtx) + decodePem(pem, "EC PUBLIC KEY", addr pkCtx, + cast[BearPemDecoderCallback](pkeyDecoderPush)) + + let k = pkCtx.key.ec + return initEcPublicKeyObj(k.curve.fromBearSslCurveConst, $cstring(k.q)) + proc toEcPrivateKey(jwk: JWK): EcPrivateKeyObj = ## Convert an ECDSA private key in JWK format to the wrapper for BearSSL's ## ECPrivateKey struct. @@ -63,6 +80,15 @@ proc toEcPrivateKey(jwk: JWK): EcPrivateKeyObj = curve = jwk.ecPrv.crv, x = b64UrlDecode(jwk.ecPrv.d)) +proc toEcPrivateKey(pem: string): EcPrivateKeyObj = + var skCtx: SkeyDecoderContext + skeyDecoderInit(addr skCtx) + decodePem(pem, "EC PRIVATE KEY", addr skCtx, + cast[BearPemDecoderCallback](skeyDecoderPush)) + + let k = skCtx.key.ec + return initEcPrivateKeyObj(k.curve.fromBearSslCurveConst, $cstring(k.x)) + proc getEcHashCfg(alg: JwtAlgorithm): HashCfg = let hashAlg = case alg: of ES256: SHA256 @@ -93,9 +119,9 @@ proc bearEcSign( let sigLen = ecSignImpl( addr ecAllM15, hashCfg.vtable, - cast[ptr cuchar](unsafeAddr hashed[0]), + cast[ptr char](unsafeAddr hashed[0]), unsafeAddr key.bearKey, - cast[ptr cuchar](addr result[0])) + cast[ptr char](addr result[0])) if sigLen == 0: raise newException(Exception, "EC signature failed") result.setLen(sigLen) @@ -112,10 +138,10 @@ proc bearEcVerify( let ecVerifyImpl = ecdsaVrfyRawGetDefault() let resultCode = ecVerifyImpl( addr ecAllM15, - cast[ptr cuchar](unsafeAddr hashed[0]), + cast[ptr char](unsafeAddr hashed[0]), hashed.len, unsafeAddr key.bearKey, - cast[ptr cuchar](unsafeAddr signature[0]), + cast[ptr char](unsafeAddr signature[0]), signature.len) result = resultCode == 1 diff --git a/src/main/private/crypto/pem.nim b/src/main/private/crypto/pem.nim new file mode 100644 index 0000000..a6555af --- /dev/null +++ b/src/main/private/crypto/pem.nim @@ -0,0 +1,106 @@ +import bearssl + +# Taken from nim-bearssl/decls.nim +{.pragma: bearSslFunc, cdecl, gcsafe, noSideEffect, raises: [].} + +type BearPemDecoderCallback* = proc(keyCtx: pointer, data: pointer, dataLen: int) {.bearSslFunc.} + +proc decodePem*( + pem: string, + expectedObjectName: string, + keyCtx: pointer, + callback: BearPemDecoderCallback + ) = + + var pemCtx: PemDecoderContext + pemDecoderInit(addr pemCtx) + + var bytesRead = 0 + var readingObj = false + + while bytesRead < len(pem): + bytesRead += pemDecoderPush(addr pemCtx, unsafeAddr pem[bytesRead], len(pem)) + + case pemDecoderEvent(addr pemCtx): + + of PEM_BEGIN_OBJ: + if readingObj: + raise newException(ValueError, + "Invalid PEM: saw a second BEGIN before seeing END.") + + if pemDecoderName(addr pemCtx) != expectedObjectName: + raise newException(ValueError, + "Invalid PEM: expected BEGIN " & expectedObjectName & + " but got BEGIN " & $pemDecoderName(addr pemCtx)) + + readingObj = true + pemDecoderSetdest(addr pemCtx, callback, keyCtx) + + of PEM_END_OBJ: + if readingObj: readingObj = false + + of PEM_ERROR: raise newException(ValueError, "Invalid PEM.") + + else: continue + +#[ +proc decodeKeyFromPem*(pem: string): JWK = + var decoderCtx: PemDecoderContext + pemDecoderInit(addr decoderCtx) + + var bytesRead = 0 + var skCtx: SkeyDecoderContext + var pkCtx: PkeyDecoderContext + var keyCtx: pointer = nil + var keyType: JwkKeyType + + while bytesRead < len(pem): + bytesRead += pemDecoderPush( + addr decoderCtx, + unsafeAddr pem[bytesRead], + len(pem)) + + case pemDecoderEvent(addr decoderCtx): + of PEM_BEGIN_OBJ: + if keyCtx != nil: + raise newException(ValueError, "Invalid PEM: saw a second BEGIN") + + let objName = pemDecoderName(addr decoderCtx) + if objName == "RSA PRIVATE KEY": + skeyDecoderInit(addr skCtx) + keyCtx = addr skCtx + keyType = JwkKeyType.RsaPrivate + if objName == "EC PRIVATE KEY": + skeyDecoderInit(addr skCtx) + keyCtx = addr skCtx + keyType = JwkKeyType.EcPrivate + elif objName == "RSA PUBLIC KEY": + pkeyDecoderInit(addr pkCtx) + keyCtx = addr pkCtx + keyType = JwkKeyType.RsaPublic + else: + raise newException(ValueError, + "Unrecognized PEM object: " & $pemDecoderName(addr decoderCtx)) + + of PEM_END_OBJ: + if keyCtx == nil: + raise newException(ValueError, "Saw END before BEGIN.") + elif keyType == JwkKeyType.RsaPrivate: + return nil + + of PEM_ERROR: + raise newException(ValueError, "Invalid PEM.") + + else: discard nil + +proc decodeRsaPem*(pem: string): JWK = + var skCtx: SkeyDecoderContext + skeyDecoderInit(addr skCtx) + decodePem( + pem, + cast[proc(ctx: pointer, data: pointer, dataLen: int) {.bearSslFunc.}](skeyDecoderPush) + addr skCtx) + if skeyDecoderLastError(addr skCtx) != 0: + raise newException(ValueError, + "Provided PEM could not be decoded as a valid RSA private key.") +]# diff --git a/src/main/private/crypto/rsa.nim b/src/main/private/crypto/rsa.nim index 66c18ce..58ab7ef 100644 --- a/src/main/private/crypto/rsa.nim +++ b/src/main/private/crypto/rsa.nim @@ -1,5 +1,5 @@ import std/options, std/tables -import bearssl +import bearssl, bearssl_pkey_decoder import ../../jwt/jwa import ../../jwt/jwk @@ -7,6 +7,7 @@ import ../../jwt/jwk import ../encoding import ./hash +import ./pem type RsaPublicKeyObj = object @@ -15,50 +16,87 @@ type RsaPrivateKeyObj = object p*, q*, dp*, dq*, iq*: string - bearKey*: RsaPrivateKey + bearKey: RsaPrivateKey func initRsaPublicKeyObj(n, e: string): RsaPublicKeyObj = result = RsaPublicKeyObj(n: n, e: e) - result.bearKey.n = cast[ptr cuchar](result.n.cstring) + result.bearKey.n = cast[ptr char](result.n.cstring) result.bearKey.nlen = result.n.len - result.bearKey.e = cast[ptr cuchar](result.e.cstring) + result.bearKey.e = cast[ptr char](result.e.cstring) result.bearKey.elen = result.e.len func initRsaPrivateKeyObj(nBitLen: int, p, q, dp, dq, iq: string): RsaPrivateKeyObj = result = RsaPrivateKeyObj(p: p, q: q, dp: dp, dq: dq, iq: iq) result.bearKey.nBitLen = cast[uint32](nBitLen) - result.bearKey.p = cast[ptr cuchar](result.p.cstring) + result.bearKey.p = cast[ptr char](result.p.cstring) result.bearKey.plen = result.p.len - result.bearKey.q = cast[ptr cuchar](result.q.cstring) + result.bearKey.q = cast[ptr char](result.q.cstring) result.bearKey.qlen = result.q.len - result.bearKey.dp = cast[ptr cuchar](result.dp.cstring) + result.bearKey.dp = cast[ptr char](result.dp.cstring) result.bearKey.dplen = result.dp.len - result.bearKey.dq = cast[ptr cuchar](result.dq.cstring) + result.bearKey.dq = cast[ptr char](result.dq.cstring) result.bearKey.dqlen = result.dq.len - result.bearKey.iq = cast[ptr cuchar](result.iq.cstring) + result.bearKey.iq = cast[ptr char](result.iq.cstring) result.bearKey.iqlen = result.iq.len proc toRsaPublicKey(jwk: JWK): RsaPublicKeyObj = ## Convert an RSA public key in JWK format to the wrapper for BearSSL's ## RsaPublicKey struct. + + if jwk.keyKind != JwkKeyType.RsaPublic and + jwk.keyKind != JwkKeyType.RsaPrivate: + raise newException(ValueError, + "Can not extract RSA public key from the given JWT: keyKind is " & + $jwk.keyKind) + return initRsaPublicKeyObj( n = b64UrlDecode(jwk.rsaPub.n), e = b64UrlDecode(jwk.rsaPub.e)) + +proc toRsaPublicKey(pem: string): RsaPublicKeyObj = + var pkCtx: PkeyDecoderContext + pkeyDecoderInit(addr pkCtx) + decodePem(pem, "RSA PUBLIC KEY", addr pkCtx, + cast[BearPemDecoderCallback](pkeyDecoderPush)) + + let k = pkCtx.key.rsa + return initRsaPublicKeyObj($cstring(k.n), $cstring(k.e)) + + proc toRsaPrivateKey(jwk: JWK): RsaPrivateKeyObj = ## Convert an RSA private key in JWK format to the wrapper for BearSSL's ## RsaPrivateKey struct. + ## + ## Note: there are two ways to represent an rsa private key[1][rsa-priv-rep]: + ## 1. The pair modulus and private exponent pair (*n*, *d*) + ## 2. The quintuple with the prime factors, reduced CRT exponents, and first + ## CRT coefficient (*p*, *q*, *dP*, *dQ*, *qInv*) + ## + ## The JWA standard requires that a JWK representing an RSA private key + ## provide rep. 1 (*n*, *d*) and optionally allows rep. 2 to allows be + ## included. BearSSL strictly uses rep. 2 in it's key structures. + ## **Currently, this library only supports keys that provide the parameters + ## for rep. 2.** Ideally this library would also support keys only containing + ## rep. 1, including the logic to derive rep. 2 from use in BearSSL. This is + ## a TODO. + ## + ## [NIST Special Publication 800-56B][NIST.SP.800-56Br2] provides a choice of + ## two algorithms to derive the prime factors *p* and *q* from the private + ## exponent *d*. With the prime factors known, the calculation *dP*, *dQ*, + ## and *qInv* are straighforward based on [their definitions in RFC8 + ## 8017][rsa-apdx-a] + ## + ## [rsa-priv-rep]: https://datatracker.ietf.org/doc/html/rfc8017#section-3.2 + ## [rsa-apdx-a]: https://datatracker.ietf.org/doc/html/rfc8017#appendix-A.1.2 + ## [NIST.SP.800-56Br2]: https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Br2.pdf - # TODO: JWS spec only requires a private key to have n, e, and d, as the - # remainder can be computed form these (p, q, dp, dq, and qi). BearSSL - # requires p, q, dp, dq, and qi (it calls iq). Because of this, we currently - # require all values to be present in JWKs for RSA privat keys. We should add - # the logic to compute the missing values to fully support the JWS spec. - # - # We also do not currently support keys with more than two prime factors. + if jwk.keyKind != JwkKeyType.RsaPrivate: + raise newException(ValueError, + "Can not extract RSA private key from the given JWT: keyKind is " & + $jwk.keyKind) let sk = jwk.rsaPrv - if sk.p.isNone or sk.q.isNone or sk.dp.isNone or sk.dq.isNone or sk.qi.isNone: raise newException(ValueError, "RSA private key must have values for: n, e, d, p, q, dp, dq, and qi") @@ -72,6 +110,22 @@ proc toRsaPrivateKey(jwk: JWK): RsaPrivateKeyObj = dq = b64UrlDecode(sk.dq.get), iq = b64UrlDecode(sk.qi.get)) +proc toRsaPrivateKey(pem: string): RsaPrivateKeyObj = + var skCtx: SkeyDecoderContext + skeyDecoderInit(addr skCtx) + decodePem(pem, "RSA PRIVATE KEY", addr skCtx, + cast[BearPemDecoderCallback](skeyDecoderPush)) + + let k = skCtx.key.rsa + + return initRsaPrivateKeyObj( + cast[int](k.nBitLen), + $cstring(k.p), + $cstring(k.q), + $cstring(k.dp), + $cstring(k.dq), + $cstring(k.iq)) + proc getRsaHashCfg(alg: JwtAlgorithm): HashCfg = let hashAlg = case alg: of RS256: SHA256 @@ -95,11 +149,11 @@ proc bearRsaSign( result = newString((key.bearKey.nBitLen + 7) div 8) let errCode = rsaSignImpl( - cast[ptr cuchar](hashCfg.oid), - cast[ptr cuchar](unsafeAddr hashed[0]), + cast[ptr char](hashCfg.oid), + cast[ptr char](unsafeAddr hashed[0]), hashed.len, unsafeAddr key.bearKey, - cast[ptr cuchar](addr result[0])) + cast[ptr char](addr result[0])) if errCode != 1: raise newException(Exception, "RSA signature failed") @@ -116,12 +170,12 @@ proc bearRsaVerify( var recoveredHash = newString(hashCfg.size) let errCode = rsaVerifyImpl( - cast[ptr cuchar](unsafeAddr signature[0]), + cast[ptr char](unsafeAddr signature[0]), signature.len, - cast[ptr cuchar](hashCfg.oid), + cast[ptr char](hashCfg.oid), hashed.len, unsafeAddr key.bearKey, - cast[ptr cuchar](addr recoveredHash[0])) + cast[ptr char](addr recoveredHash[0])) if errCode != 1: return false return hashed == recoveredHash @@ -138,6 +192,12 @@ proc rsaSign*(message: string, alg: JwtAlgorithm, key: JWK): string = return bearRsaSign(message, alg, toRsaPrivateKey(key)) +proc rsaSign*(message: string, alg: JwtAlgorithm, pemKey: string): string = + ## Sign a message using hte RSA PKCS#1 v1.5 algorithm. + ## + ## *pemKey* is expected to be an RSA private key in PEM format. + return bearRsaSign(message, alg, toRsaPrivateKey(pemKey)) + proc rsaVerify*(message, signature: string; alg: JwtAlgorithm, key: JWK): bool = ## Verify the signature for a message using PKCS#1 v1.5 algorithm. ## @@ -149,3 +209,9 @@ proc rsaVerify*(message, signature: string; alg: JwtAlgorithm, key: JWK): bool = "\"typ\"=\"" & $(key.keyKind) & "\" key.") return bearRsaVerify(message, signature, alg, toRsaPublicKey(key)) + +proc rsaVerify*(message, signature: string; alg: JwtAlgorithm, pemKey: string): bool = + ## Verify the signature for a message using PKCS#1 v1.5 algorithm. + ## + ## *key* is expected to be an RSA public key in PEM format. + return bearRsaVerify(message, signature, alg, toRsaPublicKey(pemKey))