diff --git a/src/core/ext/transport/chttp2/transport/bin_decoder.c b/src/core/ext/transport/chttp2/transport/bin_decoder.c index fe6c84bfb8f..640c29f63d2 100644 --- a/src/core/ext/transport/chttp2/transport/bin_decoder.c +++ b/src/core/ext/transport/chttp2/transport/bin_decoder.c @@ -32,31 +32,130 @@ */ #include "src/core/ext/transport/chttp2/transport/bin_decoder.h" +#include #include #include +#include "src/core/lib/support/string.h" static uint8_t decode_table[] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 62, 0, 0, 0, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, - 61, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0, 0, 0, 0, - 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, - 43, 44, 45, 46, 47, 48, 49, 50, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0}; + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 62, 0x40, 0x40, 0x40, 63, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + 49, 50, 51, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x40, 0x40, 0x40, 0x40}; static const uint8_t tail_xtra[4] = {0, 0, 1, 2}; +static inline bool input_is_valid(uint8_t *input_ptr, size_t length) { + size_t i; + + for (i = 0; i < length; ++i) { + if ((decode_table[input_ptr[i]] & 0xC0) != 0) { + gpr_log(GPR_ERROR, + "Base64 decoding failed, invalid charactor '%c' in base64 " + "input.\n", + (char)(*input_ptr)); + return false; + } + } + return true; +} + +#define COMPOSE_OUTPUT_BYTE_0(input_ptr) \ + (uint8_t)((decode_table[input_ptr[0]] << 2) | \ + (decode_table[input_ptr[1]] >> 4)) + +#define COMPOSE_OUTPUT_BYTE_1(input_ptr) \ + (uint8_t)((decode_table[ctx->input_cur[1]] << 4) | \ + (decode_table[ctx->input_cur[2]] >> 2)) + +#define COMPOSE_OUTPUT_BYTE_2(input_ptr) \ + (uint8_t)((decode_table[ctx->input_cur[2]] << 6) | \ + decode_table[ctx->input_cur[3]]) + +bool grpc_base64_decode_partial(struct grpc_base64_decode_context *ctx) { + size_t input_tail; + + if (ctx->input_cur > ctx->input_end || ctx->output_cur > ctx->output_end) { + return false; + } + + while (ctx->input_end >= ctx->input_cur + 4 && + ctx->output_end >= ctx->output_cur + 3) { + if (!input_is_valid(ctx->input_cur, 4)) return false; + ctx->output_cur[0] = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur); + ctx->output_cur[1] = COMPOSE_OUTPUT_BYTE_1(ctx->input_cur); + ctx->output_cur[2] = COMPOSE_OUTPUT_BYTE_2(ctx->input_cur); + ctx->output_cur += 3; + ctx->input_cur += 4; + } + + input_tail = (size_t)(ctx->input_end - ctx->input_cur); + if (input_tail == 4) { + // Process the input data with pad chars + if (ctx->input_cur[3] == '=') { + if (ctx->input_cur[2] == '=' && ctx->output_end >= ctx->output_cur + 1) { + if (!input_is_valid(ctx->input_cur, 2)) return false; + *(ctx->output_cur++) = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur); + ctx->input_cur += 4; + } else if (ctx->output_end >= ctx->output_cur + 2) { + if (!input_is_valid(ctx->input_cur, 3)) return false; + *(ctx->output_cur++) = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur); + *(ctx->output_cur++) = COMPOSE_OUTPUT_BYTE_1(ctx->input_cur); + ; + ctx->input_cur += 4; + } + } + + } else if (ctx->contains_tail && input_tail > 1) { + // Process the input data without pad chars, but constains_tail is set + if (ctx->output_end >= ctx->output_cur + tail_xtra[input_tail]) { + if (!input_is_valid(ctx->input_cur, input_tail)) return false; + switch (input_tail) { + case 3: + ctx->output_cur[1] = COMPOSE_OUTPUT_BYTE_1(ctx->input_cur); + case 2: + ctx->output_cur[0] = COMPOSE_OUTPUT_BYTE_0(ctx->input_cur); + } + ctx->output_cur += tail_xtra[input_tail]; + ctx->input_cur += input_tail; + } + } + + return true; +} + gpr_slice grpc_chttp2_base64_decode(gpr_slice input) { size_t input_length = GPR_SLICE_LENGTH(input); - GPR_ASSERT(input_length % 4 == 0); size_t output_length = input_length / 4 * 3; + struct grpc_base64_decode_context ctx; + gpr_slice output; + + if (input_length % 4 != 0) { + gpr_log(GPR_ERROR, + "Base64 decoding failed, input of " + "grpc_chttp2_base64_decode has a length of %zu, which is not a " + "multiple of 4.\n", + input_length); + return gpr_empty_slice(); + } if (input_length > 0) { uint8_t *input_end = GPR_SLICE_END_PTR(input); @@ -67,49 +166,66 @@ gpr_slice grpc_chttp2_base64_decode(gpr_slice input) { } } } + output = gpr_slice_malloc(output_length); - gpr_log(GPR_ERROR, "input_length: %d, output_length: %d\n", input_length, - output_length); + ctx.input_cur = GPR_SLICE_START_PTR(input); + ctx.input_end = GPR_SLICE_END_PTR(input); + ctx.output_cur = GPR_SLICE_START_PTR(output); + ctx.output_end = GPR_SLICE_END_PTR(output); + ctx.contains_tail = false; - return grpc_chttp2_base64_decode_with_length(input, output_length); + if (!grpc_base64_decode_partial(&ctx)) { + char *s = gpr_dump_slice(input, GPR_DUMP_ASCII); + gpr_log(GPR_ERROR, "Base64 decoding failed, input string:\n%s\n", s); + gpr_free(s); + gpr_slice_unref(output); + return gpr_empty_slice(); + } + GPR_ASSERT(ctx.output_cur == GPR_SLICE_END_PTR(output)); + GPR_ASSERT(ctx.input_cur == GPR_SLICE_END_PTR(input)); + return output; } gpr_slice grpc_chttp2_base64_decode_with_length(gpr_slice input, size_t output_length) { size_t input_length = GPR_SLICE_LENGTH(input); - // The length of a base64 string cannot be 4 * n + 1 - GPR_ASSERT(input_length % 4 != 1); - GPR_ASSERT(output_length <= - input_length / 4 * 3 + tail_xtra[input_length % 4]); - size_t output_triplets = output_length / 3; - size_t tail_case = output_length % 3; gpr_slice output = gpr_slice_malloc(output_length); - uint8_t *in = GPR_SLICE_START_PTR(input); - uint8_t *out = GPR_SLICE_START_PTR(output); - size_t i; + struct grpc_base64_decode_context ctx; - for (i = 0; i < output_triplets; i++) { - out[0] = (uint8_t)((decode_table[in[0]] << 2) | (decode_table[in[1]] >> 4)); - out[1] = (uint8_t)((decode_table[in[1]] << 4) | (decode_table[in[2]] >> 2)); - out[2] = (uint8_t)((decode_table[in[2]] << 6) | decode_table[in[3]]); - out += 3; - in += 4; + // The length of a base64 string cannot be 4 * n + 1 + if (input_length % 4 == 1) { + gpr_log(GPR_ERROR, + "Base64 decoding failed, input of " + "grpc_chttp2_base64_decode_with_length has a length of %zu, which " + "has a tail of 1 byte.\n", + input_length); + gpr_slice_unref(output); + return gpr_empty_slice(); } - if (tail_case > 0) { - switch (tail_case) { - case 2: - out[1] = - (uint8_t)((decode_table[in[1]] << 4) | (decode_table[in[2]] >> 2)); - case 1: - out[0] = - (uint8_t)((decode_table[in[0]] << 2) | (decode_table[in[1]] >> 4)); - } - out += tail_case; - in += tail_case + 1; + if (output_length > input_length / 4 * 3 + tail_xtra[input_length % 4]) { + gpr_log(GPR_ERROR, + "Base64 decoding failed, output_length %zu is longer " + "than the max possible output length %zu./\n", + output_length, input_length / 4 * 3 + tail_xtra[input_length % 4]); + gpr_slice_unref(output); + return gpr_empty_slice(); } - GPR_ASSERT(out == GPR_SLICE_END_PTR(output)); - GPR_ASSERT(in <= GPR_SLICE_END_PTR(input)); + ctx.input_cur = GPR_SLICE_START_PTR(input); + ctx.input_end = GPR_SLICE_END_PTR(input); + ctx.output_cur = GPR_SLICE_START_PTR(output); + ctx.output_end = GPR_SLICE_END_PTR(output); + ctx.contains_tail = true; + + if (!grpc_base64_decode_partial(&ctx)) { + char *s = gpr_dump_slice(input, GPR_DUMP_ASCII); + gpr_log(GPR_ERROR, "Base64 decoding failed, input string:\n%s\n", s); + gpr_free(s); + gpr_slice_unref(output); + return gpr_empty_slice(); + } + GPR_ASSERT(ctx.output_cur == GPR_SLICE_END_PTR(output)); + GPR_ASSERT(ctx.input_cur <= GPR_SLICE_END_PTR(input)); return output; } diff --git a/src/core/ext/transport/chttp2/transport/bin_decoder.h b/src/core/ext/transport/chttp2/transport/bin_decoder.h index 5516f86d53b..b9d40c9b74b 100644 --- a/src/core/ext/transport/chttp2/transport/bin_decoder.h +++ b/src/core/ext/transport/chttp2/transport/bin_decoder.h @@ -35,13 +35,31 @@ #define GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_BIN_DECODER_H #include +#include + +struct grpc_base64_decode_context { + /* input/output: */ + uint8_t *input_cur; + uint8_t *input_end; + uint8_t *output_cur; + uint8_t *output_end; + /* Indicate if the decoder should handle the tail of input data*/ + bool contains_tail; +}; + +/* base64 decode a grpc_base64_decode_context util either input_end is reached + or output_end is reached. When input_end is reached, (input_end - input_cur) + is less than 4. When output_end is reached, (output_end - output_cur) is less + than 3. Returns false if decoding is failed. */ +bool grpc_base64_decode_partial(struct grpc_base64_decode_context *ctx); /* base64 decode a slice with pad chars. Returns a new slice, does not take - ownership of the input */ + ownership of the input. Returns an empty slice if decoding is failed. */ gpr_slice grpc_chttp2_base64_decode(gpr_slice input); /* base64 decode a slice without pad chars, data length is needed. Returns a new - slice, does not take ownership of the input */ + slice, does not take ownership of the input. Returns an empty slice if + decoding is failed. */ gpr_slice grpc_chttp2_base64_decode_with_length(gpr_slice input, size_t output_length); diff --git a/test/core/transport/chttp2/bin_decoder_test.c b/test/core/transport/chttp2/bin_decoder_test.c index 980da02dc35..c4e6cd332f0 100644 --- a/test/core/transport/chttp2/bin_decoder_test.c +++ b/test/core/transport/chttp2/bin_decoder_test.c @@ -37,7 +37,6 @@ #include #include -#include #include "src/core/ext/transport/chttp2/transport/bin_encoder.h" #include "src/core/lib/support/string.h" @@ -72,6 +71,14 @@ static gpr_slice base64_decode(const char *s) { return out; } +static gpr_slice base64_decode_with_length(const char *s, + size_t output_length) { + gpr_slice ss = gpr_slice_from_copied_string(s); + gpr_slice out = grpc_chttp2_base64_decode_with_length(ss, output_length); + gpr_slice_unref(ss); + return out; +} + #define EXPECT_SLICE_EQ(expected, slice) \ expect_slice_eq( \ gpr_slice_from_copied_buffer(expected, sizeof(expected) - 1), slice, \ @@ -82,11 +89,9 @@ static gpr_slice base64_decode(const char *s) { s, grpc_chttp2_base64_decode_with_length(base64_encode(s), strlen(s))); int main(int argc, char **argv) { - /* - * ENCODE_AND_DECODE tests grpc_chttp2_base64_decode_with_length(), which - * takes encoded base64 strings without pad chars, but output length is - * required - */ + /* ENCODE_AND_DECODE tests grpc_chttp2_base64_decode_with_length(), which + takes encoded base64 strings without pad chars, but output length is + required. */ /* Base64 test vectors from RFC 4648 */ ENCODE_AND_DECODE(""); ENCODE_AND_DECODE("f"); @@ -116,5 +121,24 @@ int main(int argc, char **argv) { EXPECT_SLICE_EQ("\xc0\xc1\xc2\xc3\xc4\xc5", base64_decode("wMHCw8TF")); + // Test illegal input length in grpc_chttp2_base64_decode + EXPECT_SLICE_EQ("", base64_decode("a")); + EXPECT_SLICE_EQ("", base64_decode("ab")); + EXPECT_SLICE_EQ("", base64_decode("abc")); + + // Test illegal charactors in grpc_chttp2_base64_decode + EXPECT_SLICE_EQ("", base64_decode("Zm:v")); + EXPECT_SLICE_EQ("", base64_decode("Zm=v")); + + // Test output_length longer than max possible output length in + // grpc_chttp2_base64_decode_with_length + EXPECT_SLICE_EQ("", base64_decode_with_length("Zg", 2)); + EXPECT_SLICE_EQ("", base64_decode_with_length("Zm8", 3)); + EXPECT_SLICE_EQ("", base64_decode_with_length("Zm9v", 4)); + + // Test illegal charactors in grpc_chttp2_base64_decode_with_length + EXPECT_SLICE_EQ("", base64_decode_with_length("Zm:v", 3)); + EXPECT_SLICE_EQ("", base64_decode_with_length("Zm=v", 3)); + return all_ok ? 0 : 1; }