Make RSA_check_key more than 2x as fast.

The bulk of RSA_check_key is spent in bn_div_consttime, which is a naive
but constant-time long-division algorithm for the few places that divide
by a secret even divisor: RSA keygen and RSA import. RSA import is
somewhat performance-sensitive, so pick some low-hanging fruit:

The main observation is that, in all but one call site, the bit width of
the divisor is public. That means, for an N-bit divisor, we can skip the
first N-1 iterations of long division because an N-1-bit remainder
cannot exceed the N-bit divisor.

One minor nuisance is bn_lcm_consttime, used in RSA keygen has a case
that does *not* have a public bit width. Apply the optimization there
would leak information. I've implemented this as an optional public
lower bound on num_bits(divisor), which all but that call fills in.

Before:
Did 5060 RSA 2048 private key parse operations in 1058526us (4780.2 ops/sec)
Did 1551 RSA 4096 private key parse operations in 1082343us (1433.0 ops/sec)

After:
Did 11532 RSA 2048 private key parse operations in 1084145us (10637.0 ops/sec) [+122.5%]
Did 3542 RSA 4096 private key parse operations in 1036374us (3417.7 ops/sec) [+138.5%]

Bug: b/192484677
Change-Id: I893ebb8886aeb8200a1a365673b56c49774221a2
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/49106
Reviewed-by: Adam Langley <agl@google.com>
grpc-202302
David Benjamin 4 years ago committed by Adam Langley
parent 417010f9b7
commit c65543b7a9
  1. 10
      crypto/fipsmodule/bn/bn_test.cc
  2. 24
      crypto/fipsmodule/bn/div.c
  3. 5
      crypto/fipsmodule/bn/gcd_extra.c
  4. 7
      crypto/fipsmodule/bn/internal.h
  5. 27
      crypto/fipsmodule/rsa/rsa.c
  6. 6
      crypto/fipsmodule/rsa/rsa_impl.c

@ -613,9 +613,17 @@ static void TestQuotient(BIGNUMFileTest *t, BN_CTX *ctx) {
}
}
ASSERT_TRUE(bn_div_consttime(ret.get(), ret2.get(), a.get(), b.get(), ctx));
ASSERT_TRUE(bn_div_consttime(ret.get(), ret2.get(), a.get(), b.get(),
/*divisor_min_bits=*/0, ctx));
EXPECT_BIGNUMS_EQUAL("A / B (constant-time)", quotient.get(), ret.get());
EXPECT_BIGNUMS_EQUAL("A % B (constant-time)", remainder.get(), ret2.get());
ASSERT_TRUE(bn_div_consttime(ret.get(), ret2.get(), a.get(), b.get(),
/*divisor_min_bits=*/BN_num_bits(b.get()), ctx));
EXPECT_BIGNUMS_EQUAL("A / B (constant-time, public width)", quotient.get(),
ret.get());
EXPECT_BIGNUMS_EQUAL("A % B (constant-time, public width)", remainder.get(),
ret2.get());
}
static void TestModMul(BIGNUMFileTest *t, BN_CTX *ctx) {

@ -456,7 +456,7 @@ void bn_mod_add_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
int bn_div_consttime(BIGNUM *quotient, BIGNUM *remainder,
const BIGNUM *numerator, const BIGNUM *divisor,
BN_CTX *ctx) {
unsigned divisor_min_bits, BN_CTX *ctx) {
if (BN_is_negative(numerator) || BN_is_negative(divisor)) {
OPENSSL_PUT_ERROR(BN, BN_R_NEGATIVE_NUMBER);
return 0;
@ -496,8 +496,26 @@ int bn_div_consttime(BIGNUM *quotient, BIGNUM *remainder,
r->neg = 0;
// Incorporate |numerator| into |r|, one bit at a time, reducing after each
// step. At the start of each loop iteration, |r| < |divisor|
for (int i = numerator->width - 1; i >= 0; i--) {
// step. We maintain the invariant that |0 <= r < divisor| and
// |q * divisor + r = n| where |n| is the portion of |numerator| incorporated
// so far.
//
// First, we short-circuit the loop: if we know |divisor| has at least
// |divisor_min_bits| bits, the top |divisor_min_bits - 1| can be incorporated
// without reductions. This significantly speeds up |RSA_check_key|. For
// simplicity, we round down to a whole number of words.
assert(divisor_min_bits <= BN_num_bits(divisor));
int initial_words = 0;
if (divisor_min_bits > 0) {
initial_words = (divisor_min_bits - 1) / BN_BITS2;
if (initial_words > numerator->width) {
initial_words = numerator->width;
}
OPENSSL_memcpy(r->d, numerator->d + numerator->width - initial_words,
initial_words * sizeof(BN_ULONG));
}
for (int i = numerator->width - initial_words - 1; i >= 0; i--) {
for (int bit = BN_BITS2 - 1; bit >= 0; bit--) {
// Incorporate the next bit of the numerator, by computing
// r = 2*r or 2*r + 1. Note the result fits in one more word. We store the

@ -157,10 +157,11 @@ int bn_lcm_consttime(BIGNUM *r, const BIGNUM *a, const BIGNUM *b, BN_CTX *ctx) {
BN_CTX_start(ctx);
unsigned shift;
BIGNUM *gcd = BN_CTX_get(ctx);
int ret = gcd != NULL &&
int ret = gcd != NULL && //
bn_mul_consttime(r, a, b, ctx) &&
bn_gcd_consttime(gcd, &shift, a, b, ctx) &&
bn_div_consttime(r, NULL, r, gcd, ctx) &&
// |gcd| has a secret bit width.
bn_div_consttime(r, NULL, r, gcd, /*divisor_min_bits=*/0, ctx) &&
bn_rshift_secret_shift(r, r, shift, ctx);
BN_CTX_end(ctx);
return ret;

@ -552,12 +552,15 @@ int bn_sqr_consttime(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx);
// bn_div_consttime behaves like |BN_div|, but it rejects negative inputs and
// treats both inputs, including their magnitudes, as secret. It is, as a
// result, much slower than |BN_div| and should only be used for rare operations
// where Montgomery reduction is not available.
// where Montgomery reduction is not available. |divisor_min_bits| is a
// public lower bound for |BN_num_bits(divisor)|. When |divisor|'s bit width is
// public, this can speed up the operation.
//
// Note that |quotient->width| will be set pessimally to |numerator->width|.
OPENSSL_EXPORT int bn_div_consttime(BIGNUM *quotient, BIGNUM *remainder,
const BIGNUM *numerator,
const BIGNUM *divisor, BN_CTX *ctx);
const BIGNUM *divisor,
unsigned divisor_min_bits, BN_CTX *ctx);
// bn_is_relatively_prime checks whether GCD(|x|, |y|) is one. On success, it
// returns one and sets |*out_relatively_prime| to one if the GCD was one and

@ -657,7 +657,8 @@ err:
}
static int check_mod_inverse(int *out_ok, const BIGNUM *a, const BIGNUM *ainv,
const BIGNUM *m, BN_CTX *ctx) {
const BIGNUM *m, unsigned m_min_bits,
BN_CTX *ctx) {
if (BN_is_negative(ainv) || BN_cmp(ainv, m) >= 0) {
*out_ok = 0;
return 1;
@ -670,7 +671,7 @@ static int check_mod_inverse(int *out_ok, const BIGNUM *a, const BIGNUM *ainv,
BIGNUM *tmp = BN_CTX_get(ctx);
int ret = tmp != NULL &&
bn_mul_consttime(tmp, a, ainv, ctx) &&
bn_div_consttime(NULL, tmp, tmp, m, ctx);
bn_div_consttime(NULL, tmp, tmp, m, m_min_bits, ctx);
if (ret) {
*out_ok = BN_is_one(tmp);
}
@ -750,10 +751,15 @@ int RSA_check_key(const RSA *key) {
// simply check that d * e is one mod p-1 and mod q-1. Note d and e were bound
// by earlier checks in this function.
if (!bn_usub_consttime(&pm1, key->p, BN_value_one()) ||
!bn_usub_consttime(&qm1, key->q, BN_value_one()) ||
!bn_mul_consttime(&de, key->d, key->e, ctx) ||
!bn_div_consttime(NULL, &tmp, &de, &pm1, ctx) ||
!bn_div_consttime(NULL, &de, &de, &qm1, ctx)) {
!bn_usub_consttime(&qm1, key->q, BN_value_one())) {
OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
goto out;
}
const unsigned pm1_bits = BN_num_bits(&pm1);
const unsigned qm1_bits = BN_num_bits(&qm1);
if (!bn_mul_consttime(&de, key->d, key->e, ctx) ||
!bn_div_consttime(NULL, &tmp, &de, &pm1, pm1_bits, ctx) ||
!bn_div_consttime(NULL, &de, &de, &qm1, qm1_bits, ctx)) {
OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
goto out;
}
@ -772,9 +778,12 @@ int RSA_check_key(const RSA *key) {
if (has_crt_values) {
int dmp1_ok, dmq1_ok, iqmp_ok;
if (!check_mod_inverse(&dmp1_ok, key->e, key->dmp1, &pm1, ctx) ||
!check_mod_inverse(&dmq1_ok, key->e, key->dmq1, &qm1, ctx) ||
!check_mod_inverse(&iqmp_ok, key->q, key->iqmp, key->p, ctx)) {
if (!check_mod_inverse(&dmp1_ok, key->e, key->dmp1, &pm1, pm1_bits, ctx) ||
!check_mod_inverse(&dmq1_ok, key->e, key->dmq1, &qm1, qm1_bits, ctx) ||
// |p| is odd, so |pm1| and |p| have the same bit width. If they didn't,
// we only need a lower bound anyway.
!check_mod_inverse(&iqmp_ok, key->q, key->iqmp, key->p, pm1_bits,
ctx)) {
OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
goto out;
}

@ -1262,12 +1262,14 @@ static int rsa_generate_key_impl(RSA *rsa, int bits, const BIGNUM *e_value,
// values for d.
} while (BN_cmp(rsa->d, pow2_prime_bits) <= 0);
assert(BN_num_bits(pm1) == (unsigned)prime_bits);
assert(BN_num_bits(qm1) == (unsigned)prime_bits);
if (// Calculate n.
!bn_mul_consttime(rsa->n, rsa->p, rsa->q, ctx) ||
// Calculate d mod (p-1).
!bn_div_consttime(NULL, rsa->dmp1, rsa->d, pm1, ctx) ||
!bn_div_consttime(NULL, rsa->dmp1, rsa->d, pm1, prime_bits, ctx) ||
// Calculate d mod (q-1)
!bn_div_consttime(NULL, rsa->dmq1, rsa->d, qm1, ctx)) {
!bn_div_consttime(NULL, rsa->dmq1, rsa->d, qm1, prime_bits, ctx)) {
goto bn_err;
}
bn_set_minimal_width(rsa->n);

Loading…
Cancel
Save