diff --git a/BUILD b/BUILD index f502744ff9..1bdb9f461e 100644 --- a/BUILD +++ b/BUILD @@ -349,6 +349,12 @@ cc_test( ], ) +upb_proto_reflection_library( + name = "test_messages_proto3_proto_upb", + testonly = 1, + deps = ["@com_google_protobuf//:test_messages_proto3_proto"], +) + proto_library( name = "test_decoder_proto", srcs = [ @@ -516,8 +522,14 @@ upb_proto_library( deps = ["@com_google_protobuf//:conformance_proto"], ) -upb_proto_library( - name = "test_messages_proto3_proto_upb", +upb_proto_reflection_library( + name = "test_messages_proto2_upbdefs", + testonly = 1, + deps = ["@com_google_protobuf//:test_messages_proto2_proto"], +) + +upb_proto_reflection_library( + name = "test_messages_proto3_upbdefs", testonly = 1, deps = ["@com_google_protobuf//:test_messages_proto3_proto"], ) @@ -534,7 +546,9 @@ cc_binary( }) + ["-Ibazel-out/k8-fastbuild/bin"], deps = [ ":conformance_proto_upb", - ":test_messages_proto3_proto_upb", + ":test_messages_proto2_upbdefs", + ":test_messages_proto3_upbdefs", + ":reflection", ":upb", ], ) @@ -542,7 +556,7 @@ cc_binary( make_shell_script( name = "gen_test_conformance_upb", out = "test_conformance_upb.sh", - contents = "external/com_google_protobuf/conformance_test_runner ./conformance_upb", + contents = "external/com_google_protobuf/conformance_test_runner --enforce_recommended ./conformance_upb", ) sh_test( diff --git a/tests/conformance_upb.c b/tests/conformance_upb.c index 6063c9941e..ec47571197 100644 --- a/tests/conformance_upb.c +++ b/tests/conformance_upb.c @@ -9,7 +9,9 @@ #include #include "conformance/conformance.upb.h" -#include "src/google/protobuf/test_messages_proto3.upb.h" +#include "src/google/protobuf/test_messages_proto2.upbdefs.h" +#include "src/google/protobuf/test_messages_proto3.upbdefs.h" +#include "upb/reflection.h" int test_count = 0; @@ -39,138 +41,144 @@ void CheckedWrite(int fd, const void *buf, size_t len) { } } -bool strview_eql(upb_strview view, const char *str) { - return view.size == strlen(str) && memcmp(view.data, str, view.size) == 0; +typedef struct { + const conformance_ConformanceRequest *request; + conformance_ConformanceResponse *response; + upb_arena *arena; + const upb_symtab *symtab; +} ctx; + +bool parse_proto(upb_msg *msg, const upb_msgdef *m, const ctx* c) { + upb_strview proto = + conformance_ConformanceRequest_protobuf_payload(c->request); + if (upb_decode(proto.data, proto.size, msg, upb_msgdef_layout(m), c->arena)) { + return true; + } else { + static const char msg[] = "Parse error"; + conformance_ConformanceResponse_set_parse_error( + c->response, upb_strview_make(msg, strlen(msg))); + return false; + } } -static const char *proto3_msg = - "protobuf_test_messages.proto3.TestAllTypesProto3"; - -void DoTest( - const conformance_ConformanceRequest* request, - conformance_ConformanceResponse *response, - upb_arena *arena) { - protobuf_test_messages_proto3_TestAllTypesProto3 *test_message; - - if (!strview_eql(conformance_ConformanceRequest_message_type(request), - proto3_msg)) { - static const char msg[] = "Only proto3 for now."; - conformance_ConformanceResponse_set_skipped( - response, upb_strview_make(msg, sizeof(msg))); - return; +void serialize_proto(const upb_msg *msg, const upb_msgdef *m, const ctx *c) { + size_t len; + char *data = upb_encode(msg, upb_msgdef_layout(m), c->arena, &len); + if (data) { + conformance_ConformanceResponse_set_protobuf_payload( + c->response, upb_strview_make(data, len)); + } else { + static const char msg[] = "Error serializing."; + conformance_ConformanceResponse_set_serialize_error( + c->response, upb_strview_make(msg, strlen(msg))); } +} - switch (conformance_ConformanceRequest_payload_case(request)) { - case conformance_ConformanceRequest_payload_protobuf_payload: { - upb_strview payload = conformance_ConformanceRequest_protobuf_payload(request); - test_message = protobuf_test_messages_proto3_TestAllTypesProto3_parse( - payload.data, payload.size, arena); - - if (!test_message) { - static const char msg[] = "Parse error"; - conformance_ConformanceResponse_set_parse_error( - response, upb_strview_make(msg, sizeof(msg))); - return; - } - break; - } - +bool parse_input(upb_msg *msg, const upb_msgdef *m, const ctx* c) { + switch (conformance_ConformanceRequest_payload_case(c->request)) { + case conformance_ConformanceRequest_payload_protobuf_payload: + return parse_proto(msg, m, c); case conformance_ConformanceRequest_payload_NOT_SET: fprintf(stderr, "conformance_upb: Request didn't have payload.\n"); - return; - + return false; default: { static const char msg[] = "Unsupported input format."; conformance_ConformanceResponse_set_skipped( - response, upb_strview_make(msg, sizeof(msg))); - return; + c->response, upb_strview_make(msg, strlen(msg))); + return false; } } +} - switch (conformance_ConformanceRequest_requested_output_format(request)) { +void write_output(const upb_msg *msg, const upb_msgdef *m, const ctx* c) { + switch (conformance_ConformanceRequest_requested_output_format(c->request)) { case conformance_UNSPECIFIED: fprintf(stderr, "conformance_upb: Unspecified output format.\n"); exit(1); - - case conformance_PROTOBUF: { - size_t serialized_len; - char *serialized = - protobuf_test_messages_proto3_TestAllTypesProto3_serialize( - test_message, arena, &serialized_len); - if (!serialized) { - static const char msg[] = "Error serializing."; - conformance_ConformanceResponse_set_serialize_error( - response, upb_strview_make(msg, sizeof(msg))); - return; - } - conformance_ConformanceResponse_set_protobuf_payload( - response, upb_strview_make(serialized, serialized_len)); + case conformance_PROTOBUF: + serialize_proto(msg, m, c); break; - } - default: { static const char msg[] = "Unsupported output format."; conformance_ConformanceResponse_set_skipped( - response, upb_strview_make(msg, sizeof(msg))); - return; + c->response, upb_strview_make(msg, strlen(msg))); + break; } } +} + +void DoTest(const ctx* c) { + upb_msg *msg; + upb_strview name = conformance_ConformanceRequest_message_type(c->request); + const upb_msgdef *m = upb_symtab_lookupmsg2(c->symtab, name.data, name.size); + + if (!m) { + static const char msg[] = "Unknown message type."; + conformance_ConformanceResponse_set_skipped( + c->response, upb_strview_make(msg, strlen(msg))); + return; + } - return; + msg = upb_msg_new(m, c->arena); + + if (parse_input(msg, m, c)) { + write_output(msg, m, c); + } } -bool DoTestIo(void) { - upb_arena *arena; - upb_alloc *alloc; +bool DoTestIo(const upb_symtab *symtab) { upb_status status; - char *serialized_input; - char *serialized_output; + char *input; + char *output; uint32_t input_size; size_t output_size; - conformance_ConformanceRequest *request; - conformance_ConformanceResponse *response; + ctx c; if (!CheckedRead(STDIN_FILENO, &input_size, sizeof(uint32_t))) { /* EOF. */ return false; } - arena = upb_arena_new(); - alloc = upb_arena_alloc(arena); - serialized_input = upb_malloc(alloc, input_size); + c.symtab = symtab; + c.arena = upb_arena_new(); + input = upb_arena_malloc(c.arena, input_size); - if (!CheckedRead(STDIN_FILENO, serialized_input, input_size)) { + if (!CheckedRead(STDIN_FILENO, input, input_size)) { fprintf(stderr, "conformance_upb: unexpected EOF on stdin.\n"); exit(1); } - request = - conformance_ConformanceRequest_parse(serialized_input, input_size, arena); - response = conformance_ConformanceResponse_new(arena); + c.request = conformance_ConformanceRequest_parse(input, input_size, c.arena); + c.response = conformance_ConformanceResponse_new(c.arena); - if (request) { - DoTest(request, response, arena); + if (c.request) { + DoTest(&c); } else { fprintf(stderr, "conformance_upb: parse of ConformanceRequest failed: %s\n", upb_status_errmsg(&status)); } - serialized_output = conformance_ConformanceResponse_serialize( - response, arena, &output_size); + output = conformance_ConformanceResponse_serialize(c.response, c.arena, + &output_size); CheckedWrite(STDOUT_FILENO, &output_size, sizeof(uint32_t)); - CheckedWrite(STDOUT_FILENO, serialized_output, output_size); + CheckedWrite(STDOUT_FILENO, output, output_size); test_count++; - upb_arena_free(arena); + upb_arena_free(c.arena); return true; } int main(void) { + upb_symtab *symtab = upb_symtab_new(); + + protobuf_test_messages_proto2_TestAllTypesProto2_getmsgdef(symtab); + protobuf_test_messages_proto3_TestAllTypesProto3_getmsgdef(symtab); + while (1) { - if (!DoTestIo()) { + if (!DoTestIo(symtab)) { fprintf(stderr, "conformance_upb: received EOF from test runner " "after %d tests, exiting\n", test_count); return 0; diff --git a/upb/decode.c b/upb/decode.c index 9f1e986e73..ce5552af40 100644 --- a/upb/decode.c +++ b/upb/decode.c @@ -63,8 +63,6 @@ typedef struct { #define CHK(x) if (!(x)) { return 0; } #define PTR_AT(msg, ofs, type) (type*)((const char*)msg + ofs) -static const char *upb_skip_unknowngroup(const char *ptr, upb_decstate *d, - int field_number); static const char *upb_decode_message(const char *ptr, const upb_msglayout *l, upb_msg *msg, upb_decstate *d); @@ -132,14 +130,14 @@ static void upb_set32(void *msg, size_t ofs, uint32_t val) { memcpy((char*)msg + ofs, &val, sizeof(val)); } -static const char *upb_append_unknown(const char *ptr, upb_msg *msg, upb_decstate *d) { +static const char *upb_append_unknown(const char *ptr, upb_msg *msg, + upb_decstate *d) { upb_msg_addunknown(msg, d->field_start, ptr - d->field_start, d->arena); return ptr; } static const char *upb_skip_unknownfielddata(const char *ptr, upb_decstate *d, - uint32_t tag, - uint32_t group_fieldnum) { + uint32_t tag) { switch (tag & 7) { case UPB_WIRE_TYPE_VARINT: { uint64_t val; @@ -158,28 +156,24 @@ static const char *upb_skip_unknownfielddata(const char *ptr, upb_decstate *d, CHK(ptr = upb_decode_string(ptr, d->limit, &len)); return ptr + len; } - case UPB_WIRE_TYPE_START_GROUP: - return upb_skip_unknowngroup(ptr, d, tag >> 3); + case UPB_WIRE_TYPE_START_GROUP: { + uint32_t field_number = tag >> 3; + while (ptr < d->limit && d->end_group == 0) { + uint32_t tag = 0; + CHK(ptr = upb_decode_varint32(ptr, d->limit, &tag)); + CHK(ptr = upb_skip_unknownfielddata(ptr, d, tag)); + } + CHK(d->end_group == field_number); + d->end_group = 0; + return ptr; + } case UPB_WIRE_TYPE_END_GROUP: - CHK((tag >> 3) == group_fieldnum); + d->end_group = tag >> 3; return ptr; } return false; } -static const char *upb_skip_unknowngroup(const char *ptr, upb_decstate *d, - int field_number) { - while (ptr < d->limit && d->end_group == 0) { - uint32_t tag = 0; - CHK(ptr = upb_decode_varint32(ptr, d->limit, &tag)); - CHK(ptr = upb_skip_unknownfielddata(ptr, d, tag, field_number)); - } - - CHK(d->end_group == field_number); - d->end_group = 0; - return ptr; -} - static void *upb_array_reserve(upb_array *arr, size_t elements, size_t elem_size, upb_arena *arena) { if (arr->size - arr->len < elements) { @@ -604,7 +598,7 @@ static const char *upb_decode_field(const char *ptr, } } else { CHK(field_number != 0); - CHK(ptr = upb_skip_unknownfielddata(ptr, d, tag, -1)); + CHK(ptr = upb_skip_unknownfielddata(ptr, d, tag)); CHK(ptr = upb_append_unknown(ptr, msg, d)); return ptr; }