diff --git a/BUILD b/BUILD index 65cd4715f10..c8ad45a7887 100644 --- a/BUILD +++ b/BUILD @@ -2065,6 +2065,8 @@ grpc_cc_library( "include/grpcpp/impl/codegen/byte_buffer.h", "include/grpcpp/impl/codegen/call.h", "include/grpcpp/impl/codegen/call_hook.h", + "include/grpcpp/impl/codegen/call_op_set.h", + "include/grpcpp/impl/codegen/call_op_set_interface.h", "include/grpcpp/impl/codegen/callback_common.h", "include/grpcpp/impl/codegen/channel_interface.h", "include/grpcpp/impl/codegen/client_callback.h", @@ -2077,7 +2079,9 @@ grpc_cc_library( "include/grpcpp/impl/codegen/core_codegen_interface.h", "include/grpcpp/impl/codegen/create_auth_context.h", "include/grpcpp/impl/codegen/grpc_library.h", + "include/grpcpp/impl/codegen/intercepted_channel.h", "include/grpcpp/impl/codegen/interceptor.h", + "include/grpcpp/impl/codegen/interceptor_common.h", "include/grpcpp/impl/codegen/metadata_map.h", "include/grpcpp/impl/codegen/method_handler_impl.h", "include/grpcpp/impl/codegen/rpc_method.h", @@ -2085,6 +2089,7 @@ grpc_cc_library( "include/grpcpp/impl/codegen/security/auth_context.h", "include/grpcpp/impl/codegen/serialization_traits.h", "include/grpcpp/impl/codegen/server_context.h", + "include/grpcpp/impl/codegen/server_interceptor.h", "include/grpcpp/impl/codegen/server_interface.h", "include/grpcpp/impl/codegen/service_type.h", "include/grpcpp/impl/codegen/slice.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index 23ef8b3b640..3b41ec6c6ed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -584,6 +584,7 @@ if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) add_dependencies(buildtests_cxx client_crash_test) endif() add_dependencies(buildtests_cxx client_crash_test_server) +add_dependencies(buildtests_cxx client_interceptors_end2end_test) add_dependencies(buildtests_cxx client_lb_end2end_test) add_dependencies(buildtests_cxx codegen_test_full) add_dependencies(buildtests_cxx codegen_test_minimal) @@ -3112,6 +3113,8 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_op_set.h + include/grpcpp/impl/codegen/call_op_set_interface.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -3124,7 +3127,9 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h + include/grpcpp/impl/codegen/interceptor_common.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h include/grpcpp/impl/codegen/rpc_method.h @@ -3132,6 +3137,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -3687,6 +3693,8 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_op_set.h + include/grpcpp/impl/codegen/call_op_set_interface.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -3699,7 +3707,9 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h + include/grpcpp/impl/codegen/interceptor_common.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h include/grpcpp/impl/codegen/rpc_method.h @@ -3707,6 +3717,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -4096,6 +4107,8 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_op_set.h + include/grpcpp/impl/codegen/call_op_set_interface.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -4108,7 +4121,9 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h + include/grpcpp/impl/codegen/interceptor_common.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h include/grpcpp/impl/codegen/rpc_method.h @@ -4116,6 +4131,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -4277,6 +4293,8 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_op_set.h + include/grpcpp/impl/codegen/call_op_set_interface.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -4289,7 +4307,9 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h + include/grpcpp/impl/codegen/interceptor_common.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h include/grpcpp/impl/codegen/rpc_method.h @@ -4297,6 +4317,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -4602,6 +4623,8 @@ foreach(_hdr include/grpcpp/impl/codegen/byte_buffer.h include/grpcpp/impl/codegen/call.h include/grpcpp/impl/codegen/call_hook.h + include/grpcpp/impl/codegen/call_op_set.h + include/grpcpp/impl/codegen/call_op_set_interface.h include/grpcpp/impl/codegen/callback_common.h include/grpcpp/impl/codegen/channel_interface.h include/grpcpp/impl/codegen/client_callback.h @@ -4614,7 +4637,9 @@ foreach(_hdr include/grpcpp/impl/codegen/core_codegen_interface.h include/grpcpp/impl/codegen/create_auth_context.h include/grpcpp/impl/codegen/grpc_library.h + include/grpcpp/impl/codegen/intercepted_channel.h include/grpcpp/impl/codegen/interceptor.h + include/grpcpp/impl/codegen/interceptor_common.h include/grpcpp/impl/codegen/metadata_map.h include/grpcpp/impl/codegen/method_handler_impl.h include/grpcpp/impl/codegen/rpc_method.h @@ -4622,6 +4647,7 @@ foreach(_hdr include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h include/grpcpp/impl/codegen/server_context.h + include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h include/grpcpp/impl/codegen/service_type.h include/grpcpp/impl/codegen/slice.h @@ -12402,6 +12428,47 @@ target_link_libraries(client_crash_test_server ) +endif (gRPC_BUILD_TESTS) +if (gRPC_BUILD_TESTS) + +add_executable(client_interceptors_end2end_test + test/cpp/end2end/client_interceptors_end2end_test.cc + third_party/googletest/googletest/src/gtest-all.cc + third_party/googletest/googlemock/src/gmock-all.cc +) + + +target_include_directories(client_interceptors_end2end_test + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include + PRIVATE ${_gRPC_SSL_INCLUDE_DIR} + PRIVATE ${_gRPC_PROTOBUF_INCLUDE_DIR} + PRIVATE ${_gRPC_ZLIB_INCLUDE_DIR} + PRIVATE ${_gRPC_BENCHMARK_INCLUDE_DIR} + PRIVATE ${_gRPC_CARES_INCLUDE_DIR} + PRIVATE ${_gRPC_GFLAGS_INCLUDE_DIR} + PRIVATE ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + PRIVATE ${_gRPC_NANOPB_INCLUDE_DIR} + PRIVATE third_party/googletest/googletest/include + PRIVATE third_party/googletest/googletest + PRIVATE third_party/googletest/googlemock/include + PRIVATE third_party/googletest/googlemock + PRIVATE ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(client_interceptors_end2end_test + ${_gRPC_PROTOBUF_LIBRARIES} + ${_gRPC_ALLTARGETS_LIBRARIES} + grpc++_test_util + grpc_test_util + grpc++ + grpc + gpr_test_util + gpr + ${_gRPC_GFLAGS_LIBRARIES} +) + + endif (gRPC_BUILD_TESTS) if (gRPC_BUILD_TESTS) diff --git a/Makefile b/Makefile index d0a3f863830..aafdebb6ca5 100644 --- a/Makefile +++ b/Makefile @@ -1164,6 +1164,7 @@ client_callback_end2end_test: $(BINDIR)/$(CONFIG)/client_callback_end2end_test client_channel_stress_test: $(BINDIR)/$(CONFIG)/client_channel_stress_test client_crash_test: $(BINDIR)/$(CONFIG)/client_crash_test client_crash_test_server: $(BINDIR)/$(CONFIG)/client_crash_test_server +client_interceptors_end2end_test: $(BINDIR)/$(CONFIG)/client_interceptors_end2end_test client_lb_end2end_test: $(BINDIR)/$(CONFIG)/client_lb_end2end_test codegen_test_full: $(BINDIR)/$(CONFIG)/codegen_test_full codegen_test_minimal: $(BINDIR)/$(CONFIG)/codegen_test_minimal @@ -1667,6 +1668,7 @@ buildtests_cxx: privatelibs_cxx \ $(BINDIR)/$(CONFIG)/client_channel_stress_test \ $(BINDIR)/$(CONFIG)/client_crash_test \ $(BINDIR)/$(CONFIG)/client_crash_test_server \ + $(BINDIR)/$(CONFIG)/client_interceptors_end2end_test \ $(BINDIR)/$(CONFIG)/client_lb_end2end_test \ $(BINDIR)/$(CONFIG)/codegen_test_full \ $(BINDIR)/$(CONFIG)/codegen_test_minimal \ @@ -1848,6 +1850,7 @@ buildtests_cxx: privatelibs_cxx \ $(BINDIR)/$(CONFIG)/client_channel_stress_test \ $(BINDIR)/$(CONFIG)/client_crash_test \ $(BINDIR)/$(CONFIG)/client_crash_test_server \ + $(BINDIR)/$(CONFIG)/client_interceptors_end2end_test \ $(BINDIR)/$(CONFIG)/client_lb_end2end_test \ $(BINDIR)/$(CONFIG)/codegen_test_full \ $(BINDIR)/$(CONFIG)/codegen_test_minimal \ @@ -2306,6 +2309,8 @@ test_cxx: buildtests_cxx $(Q) $(BINDIR)/$(CONFIG)/client_channel_stress_test || ( echo test client_channel_stress_test failed ; exit 1 ) $(E) "[RUN] Testing client_crash_test" $(Q) $(BINDIR)/$(CONFIG)/client_crash_test || ( echo test client_crash_test failed ; exit 1 ) + $(E) "[RUN] Testing client_interceptors_end2end_test" + $(Q) $(BINDIR)/$(CONFIG)/client_interceptors_end2end_test || ( echo test client_interceptors_end2end_test failed ; exit 1 ) $(E) "[RUN] Testing client_lb_end2end_test" $(Q) $(BINDIR)/$(CONFIG)/client_lb_end2end_test || ( echo test client_lb_end2end_test failed ; exit 1 ) $(E) "[RUN] Testing codegen_test_full" @@ -5455,6 +5460,8 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_op_set.h \ + include/grpcpp/impl/codegen/call_op_set_interface.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -5467,7 +5474,9 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ + include/grpcpp/impl/codegen/interceptor_common.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ include/grpcpp/impl/codegen/rpc_method.h \ @@ -5475,6 +5484,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -6039,6 +6049,8 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_op_set.h \ + include/grpcpp/impl/codegen/call_op_set_interface.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -6051,7 +6063,9 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ + include/grpcpp/impl/codegen/interceptor_common.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ include/grpcpp/impl/codegen/rpc_method.h \ @@ -6059,6 +6073,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -6433,6 +6448,8 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_op_set.h \ + include/grpcpp/impl/codegen/call_op_set_interface.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -6445,7 +6462,9 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ + include/grpcpp/impl/codegen/interceptor_common.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ include/grpcpp/impl/codegen/rpc_method.h \ @@ -6453,6 +6472,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -6591,6 +6611,8 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_op_set.h \ + include/grpcpp/impl/codegen/call_op_set_interface.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -6603,7 +6625,9 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ + include/grpcpp/impl/codegen/interceptor_common.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ include/grpcpp/impl/codegen/rpc_method.h \ @@ -6611,6 +6635,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -6921,6 +6946,8 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ + include/grpcpp/impl/codegen/call_op_set.h \ + include/grpcpp/impl/codegen/call_op_set_interface.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -6933,7 +6960,9 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ + include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ + include/grpcpp/impl/codegen/interceptor_common.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ include/grpcpp/impl/codegen/rpc_method.h \ @@ -6941,6 +6970,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ + include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ @@ -17279,6 +17309,49 @@ endif endif +CLIENT_INTERCEPTORS_END2END_TEST_SRC = \ + test/cpp/end2end/client_interceptors_end2end_test.cc \ + +CLIENT_INTERCEPTORS_END2END_TEST_OBJS = $(addprefix $(OBJDIR)/$(CONFIG)/, $(addsuffix .o, $(basename $(CLIENT_INTERCEPTORS_END2END_TEST_SRC)))) +ifeq ($(NO_SECURE),true) + +# You can't build secure targets if you don't have OpenSSL. + +$(BINDIR)/$(CONFIG)/client_interceptors_end2end_test: openssl_dep_error + +else + + + + +ifeq ($(NO_PROTOBUF),true) + +# You can't build the protoc plugins or protobuf-enabled targets if you don't have protobuf 3.5.0+. + +$(BINDIR)/$(CONFIG)/client_interceptors_end2end_test: protobuf_dep_error + +else + +$(BINDIR)/$(CONFIG)/client_interceptors_end2end_test: $(PROTOBUF_DEP) $(CLIENT_INTERCEPTORS_END2END_TEST_OBJS) $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr_test_util.a $(LIBDIR)/$(CONFIG)/libgpr.a + $(E) "[LD] Linking $@" + $(Q) mkdir -p `dirname $@` + $(Q) $(LDXX) $(LDFLAGS) $(CLIENT_INTERCEPTORS_END2END_TEST_OBJS) $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr_test_util.a $(LIBDIR)/$(CONFIG)/libgpr.a $(LDLIBSXX) $(LDLIBS_PROTOBUF) $(LDLIBS) $(LDLIBS_SECURE) $(GTEST_LIB) -o $(BINDIR)/$(CONFIG)/client_interceptors_end2end_test + +endif + +endif + +$(OBJDIR)/$(CONFIG)/test/cpp/end2end/client_interceptors_end2end_test.o: $(LIBDIR)/$(CONFIG)/libgrpc++_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc_test_util.a $(LIBDIR)/$(CONFIG)/libgrpc++.a $(LIBDIR)/$(CONFIG)/libgrpc.a $(LIBDIR)/$(CONFIG)/libgpr_test_util.a $(LIBDIR)/$(CONFIG)/libgpr.a + +deps_client_interceptors_end2end_test: $(CLIENT_INTERCEPTORS_END2END_TEST_OBJS:.o=.dep) + +ifneq ($(NO_SECURE),true) +ifneq ($(NO_DEPS),true) +-include $(CLIENT_INTERCEPTORS_END2END_TEST_OBJS:.o=.dep) +endif +endif + + CLIENT_LB_END2END_TEST_SRC = \ test/cpp/end2end/client_lb_end2end_test.cc \ diff --git a/build.yaml b/build.yaml index d8d51c1b94c..e5955007d4d 100644 --- a/build.yaml +++ b/build.yaml @@ -1222,6 +1222,8 @@ filegroups: - include/grpcpp/impl/codegen/byte_buffer.h - include/grpcpp/impl/codegen/call.h - include/grpcpp/impl/codegen/call_hook.h + - include/grpcpp/impl/codegen/call_op_set.h + - include/grpcpp/impl/codegen/call_op_set_interface.h - include/grpcpp/impl/codegen/callback_common.h - include/grpcpp/impl/codegen/channel_interface.h - include/grpcpp/impl/codegen/client_callback.h @@ -1234,7 +1236,9 @@ filegroups: - include/grpcpp/impl/codegen/core_codegen_interface.h - include/grpcpp/impl/codegen/create_auth_context.h - include/grpcpp/impl/codegen/grpc_library.h + - include/grpcpp/impl/codegen/intercepted_channel.h - include/grpcpp/impl/codegen/interceptor.h + - include/grpcpp/impl/codegen/interceptor_common.h - include/grpcpp/impl/codegen/metadata_map.h - include/grpcpp/impl/codegen/method_handler_impl.h - include/grpcpp/impl/codegen/rpc_method.h @@ -1242,6 +1246,7 @@ filegroups: - include/grpcpp/impl/codegen/security/auth_context.h - include/grpcpp/impl/codegen/serialization_traits.h - include/grpcpp/impl/codegen/server_context.h + - include/grpcpp/impl/codegen/server_interceptor.h - include/grpcpp/impl/codegen/server_interface.h - include/grpcpp/impl/codegen/service_type.h - include/grpcpp/impl/codegen/slice.h @@ -4523,6 +4528,22 @@ targets: - grpc - gpr_test_util - gpr +- name: client_interceptors_end2end_test + gtest: true + cpu_cost: 0.5 + build: test + language: c++ + headers: + - test/cpp/end2end/interceptors_util.h + src: + - test/cpp/end2end/client_interceptors_end2end_test.cc + deps: + - grpc++_test_util + - grpc_test_util + - grpc++ + - grpc + - gpr_test_util + - gpr - name: client_lb_end2end_test gtest: true build: test diff --git a/gRPC-C++.podspec b/gRPC-C++.podspec index d15a00e5f72..485646171c7 100644 --- a/gRPC-C++.podspec +++ b/gRPC-C++.podspec @@ -131,6 +131,8 @@ Pod::Spec.new do |s| 'include/grpcpp/impl/codegen/byte_buffer.h', 'include/grpcpp/impl/codegen/call.h', 'include/grpcpp/impl/codegen/call_hook.h', + 'include/grpcpp/impl/codegen/call_op_set.h', + 'include/grpcpp/impl/codegen/call_op_set_interface.h', 'include/grpcpp/impl/codegen/callback_common.h', 'include/grpcpp/impl/codegen/channel_interface.h', 'include/grpcpp/impl/codegen/client_callback.h', @@ -143,7 +145,9 @@ Pod::Spec.new do |s| 'include/grpcpp/impl/codegen/core_codegen_interface.h', 'include/grpcpp/impl/codegen/create_auth_context.h', 'include/grpcpp/impl/codegen/grpc_library.h', + 'include/grpcpp/impl/codegen/intercepted_channel.h', 'include/grpcpp/impl/codegen/interceptor.h', + 'include/grpcpp/impl/codegen/interceptor_common.h', 'include/grpcpp/impl/codegen/metadata_map.h', 'include/grpcpp/impl/codegen/method_handler_impl.h', 'include/grpcpp/impl/codegen/rpc_method.h', @@ -151,6 +155,7 @@ Pod::Spec.new do |s| 'include/grpcpp/impl/codegen/security/auth_context.h', 'include/grpcpp/impl/codegen/serialization_traits.h', 'include/grpcpp/impl/codegen/server_context.h', + 'include/grpcpp/impl/codegen/server_interceptor.h', 'include/grpcpp/impl/codegen/server_interface.h', 'include/grpcpp/impl/codegen/service_type.h', 'include/grpcpp/impl/codegen/slice.h', diff --git a/include/grpcpp/channel.h b/include/grpcpp/channel.h index b7c9e354de1..14209b85eeb 100644 --- a/include/grpcpp/channel.h +++ b/include/grpcpp/channel.h @@ -20,6 +20,7 @@ #define GRPCPP_CHANNEL_H #include +#include #include #include @@ -67,6 +68,7 @@ class Channel final : public ChannelInterface, std::unique_ptr>> interceptor_creators); + friend class internal::InterceptedChannel; Channel(const grpc::string& host, grpc_channel* c_channel, std::unique_ptr>> @@ -87,6 +89,10 @@ class Channel final : public ChannelInterface, CompletionQueue* CallbackCQ() override; + internal::Call CreateCallInternal(const internal::RpcMethod& method, + ClientContext* context, CompletionQueue* cq, + int interceptor_pos) override; + const grpc::string host_; grpc_channel* const c_channel_; // owned diff --git a/include/grpcpp/impl/codegen/async_stream.h b/include/grpcpp/impl/codegen/async_stream.h index 6e58fd0eefd..bfb2df4f232 100644 --- a/include/grpcpp/impl/codegen/async_stream.h +++ b/include/grpcpp/impl/codegen/async_stream.h @@ -276,7 +276,7 @@ class ClientAsyncReader final : public ClientAsyncReaderInterface { } void StartCallInternal(void* tag) { - init_ops_.SendInitialMetadata(context_->send_initial_metadata_, + init_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); init_ops_.set_output_tag(tag); call_.PerformOps(&init_ops_); @@ -441,7 +441,7 @@ class ClientAsyncWriter final : public ClientAsyncWriterInterface { } void StartCallInternal(void* tag) { - write_ops_.SendInitialMetadata(context_->send_initial_metadata_, + write_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); // if corked bit is set in context, we just keep the initial metadata // buffered up to coalesce with later message send. No op is performed. @@ -612,7 +612,7 @@ class ClientAsyncReaderWriter final } void StartCallInternal(void* tag) { - write_ops_.SendInitialMetadata(context_->send_initial_metadata_, + write_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); // if corked bit is set in context, we just keep the initial metadata // buffered up to coalesce with later message send. No op is performed. @@ -710,7 +710,7 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_ops_.set_output_tag(tag); - meta_ops_.SendInitialMetadata(ctx_->initial_metadata_, + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_ops_.set_compression_level(ctx_->compression_level()); @@ -739,7 +739,7 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface { void Finish(const W& msg, const Status& status, void* tag) override { finish_ops_.set_output_tag(tag); if (!ctx_->sent_initial_metadata_) { - finish_ops_.SendInitialMetadata(ctx_->initial_metadata_, + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_ops_.set_compression_level(ctx_->compression_level()); @@ -748,10 +748,10 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface { } // The response is dropped if the status is not OK. if (status.ok()) { - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, finish_ops_.SendMessage(msg)); } else { - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); } call_.PerformOps(&finish_ops_); } @@ -769,14 +769,14 @@ class ServerAsyncReader final : public ServerAsyncReaderInterface { GPR_CODEGEN_ASSERT(!status.ok()); finish_ops_.set_output_tag(tag); if (!ctx_->sent_initial_metadata_) { - finish_ops_.SendInitialMetadata(ctx_->initial_metadata_, + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_ops_.set_compression_level(ctx_->compression_level()); } ctx_->sent_initial_metadata_ = true; } - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_ops_); } @@ -859,7 +859,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_ops_.set_output_tag(tag); - meta_ops_.SendInitialMetadata(ctx_->initial_metadata_, + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_ops_.set_compression_level(ctx_->compression_level()); @@ -904,7 +904,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface { EnsureInitialMetadataSent(&write_ops_); options.set_buffer_hint(); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(msg, options).ok()); - write_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + write_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&write_ops_); } @@ -922,7 +922,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface { void Finish(const Status& status, void* tag) override { finish_ops_.set_output_tag(tag); EnsureInitialMetadataSent(&finish_ops_); - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_ops_); } @@ -932,7 +932,7 @@ class ServerAsyncWriter final : public ServerAsyncWriterInterface { template void EnsureInitialMetadataSent(T* ops) { if (!ctx_->sent_initial_metadata_) { - ops->SendInitialMetadata(ctx_->initial_metadata_, + ops->SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops->set_compression_level(ctx_->compression_level()); @@ -1025,7 +1025,7 @@ class ServerAsyncReaderWriter final GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_ops_.set_output_tag(tag); - meta_ops_.SendInitialMetadata(ctx_->initial_metadata_, + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_ops_.set_compression_level(ctx_->compression_level()); @@ -1075,7 +1075,7 @@ class ServerAsyncReaderWriter final EnsureInitialMetadataSent(&write_ops_); options.set_buffer_hint(); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(msg, options).ok()); - write_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + write_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&write_ops_); } @@ -1094,7 +1094,7 @@ class ServerAsyncReaderWriter final finish_ops_.set_output_tag(tag); EnsureInitialMetadataSent(&finish_ops_); - finish_ops_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_ops_); } @@ -1106,7 +1106,7 @@ class ServerAsyncReaderWriter final template void EnsureInitialMetadataSent(T* ops) { if (!ctx_->sent_initial_metadata_) { - ops->SendInitialMetadata(ctx_->initial_metadata_, + ops->SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops->set_compression_level(ctx_->compression_level()); diff --git a/include/grpcpp/impl/codegen/async_unary_call.h b/include/grpcpp/impl/codegen/async_unary_call.h index 60ff8e2f054..744b1281411 100644 --- a/include/grpcpp/impl/codegen/async_unary_call.h +++ b/include/grpcpp/impl/codegen/async_unary_call.h @@ -174,7 +174,7 @@ class ClientAsyncResponseReader final } void StartCallInternal() { - single_buf.SendInitialMetadata(context_->send_initial_metadata_, + single_buf.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); } @@ -214,7 +214,7 @@ class ServerAsyncResponseWriter final GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); meta_buf_.set_output_tag(tag); - meta_buf_.SendInitialMetadata(ctx_->initial_metadata_, + meta_buf_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { meta_buf_.set_compression_level(ctx_->compression_level()); @@ -240,8 +240,9 @@ class ServerAsyncResponseWriter final /// metadata. void Finish(const W& msg, const Status& status, void* tag) { finish_buf_.set_output_tag(tag); + finish_buf_.set_cq_tag(&finish_buf_); if (!ctx_->sent_initial_metadata_) { - finish_buf_.SendInitialMetadata(ctx_->initial_metadata_, + finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_buf_.set_compression_level(ctx_->compression_level()); @@ -250,10 +251,10 @@ class ServerAsyncResponseWriter final } // The response is dropped if the status is not OK. if (status.ok()) { - finish_buf_.ServerSendStatus(ctx_->trailing_metadata_, + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, finish_buf_.SendMessage(msg)); } else { - finish_buf_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, status); } call_.PerformOps(&finish_buf_); } @@ -274,14 +275,14 @@ class ServerAsyncResponseWriter final GPR_CODEGEN_ASSERT(!status.ok()); finish_buf_.set_output_tag(tag); if (!ctx_->sent_initial_metadata_) { - finish_buf_.SendInitialMetadata(ctx_->initial_metadata_, + finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { finish_buf_.set_compression_level(ctx_->compression_level()); } ctx_->sent_initial_metadata_ = true; } - finish_buf_.ServerSendStatus(ctx_->trailing_metadata_, status); + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } diff --git a/include/grpcpp/impl/codegen/byte_buffer.h b/include/grpcpp/impl/codegen/byte_buffer.h index 8cc51581152..d54ae31852e 100644 --- a/include/grpcpp/impl/codegen/byte_buffer.h +++ b/include/grpcpp/impl/codegen/byte_buffer.h @@ -50,6 +50,11 @@ class ErrorMethodHandler; template class DeserializeFuncType; class GrpcByteBufferPeer; +template +class RpcMethodHandler; +template +class ServerStreamingHandler; + } // namespace internal /// A sequence of bytes. class ByteBuffer final { @@ -141,7 +146,10 @@ class ByteBuffer final { template friend class internal::CallOpRecvMessage; friend class internal::CallOpGenericRecvMessage; - friend class internal::MethodHandler; + template + friend class RpcMethodHandler; + template + friend class ServerStreamingHandler; template friend class internal::RpcMethodHandler; template diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index 789ea805a3c..c040c30dd9d 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -1,6 +1,6 @@ /* * - * Copyright 2015 gRPC authors. + * Copyright 2018 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,671 +15,31 @@ * limitations under the License. * */ - #ifndef GRPCPP_IMPL_CODEGEN_CALL_H #define GRPCPP_IMPL_CODEGEN_CALL_H -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include #include +#include namespace grpc { - -class ByteBuffer; class CompletionQueue; -extern CoreCodegenInterface* g_core_codegen_interface; +namespace experimental { +class ClientRpcInfo; +class ServerRpcInfo; +} // namespace experimental namespace internal { -class Call; class CallHook; - -// TODO(yangg) if the map is changed before we send, the pointers will be a -// mess. Make sure it does not happen. -inline grpc_metadata* FillMetadataArray( - const std::multimap& metadata, - size_t* metadata_count, const grpc::string& optional_error_details) { - *metadata_count = metadata.size() + (optional_error_details.empty() ? 0 : 1); - if (*metadata_count == 0) { - return nullptr; - } - grpc_metadata* metadata_array = - (grpc_metadata*)(g_core_codegen_interface->gpr_malloc( - (*metadata_count) * sizeof(grpc_metadata))); - size_t i = 0; - for (auto iter = metadata.cbegin(); iter != metadata.cend(); ++iter, ++i) { - metadata_array[i].key = SliceReferencingString(iter->first); - metadata_array[i].value = SliceReferencingString(iter->second); - } - if (!optional_error_details.empty()) { - metadata_array[i].key = - g_core_codegen_interface->grpc_slice_from_static_buffer( - kBinaryErrorDetailsKey, sizeof(kBinaryErrorDetailsKey) - 1); - metadata_array[i].value = SliceReferencingString(optional_error_details); - } - return metadata_array; -} -} // namespace internal - -/// Per-message write options. -class WriteOptions { - public: - WriteOptions() : flags_(0), last_message_(false) {} - WriteOptions(const WriteOptions& other) - : flags_(other.flags_), last_message_(other.last_message_) {} - - /// Clear all flags. - inline void Clear() { flags_ = 0; } - - /// Returns raw flags bitset. - inline uint32_t flags() const { return flags_; } - - /// Sets flag for the disabling of compression for the next message write. - /// - /// \sa GRPC_WRITE_NO_COMPRESS - inline WriteOptions& set_no_compression() { - SetBit(GRPC_WRITE_NO_COMPRESS); - return *this; - } - - /// Clears flag for the disabling of compression for the next message write. - /// - /// \sa GRPC_WRITE_NO_COMPRESS - inline WriteOptions& clear_no_compression() { - ClearBit(GRPC_WRITE_NO_COMPRESS); - return *this; - } - - /// Get value for the flag indicating whether compression for the next - /// message write is forcefully disabled. - /// - /// \sa GRPC_WRITE_NO_COMPRESS - inline bool get_no_compression() const { - return GetBit(GRPC_WRITE_NO_COMPRESS); - } - - /// Sets flag indicating that the write may be buffered and need not go out on - /// the wire immediately. - /// - /// \sa GRPC_WRITE_BUFFER_HINT - inline WriteOptions& set_buffer_hint() { - SetBit(GRPC_WRITE_BUFFER_HINT); - return *this; - } - - /// Clears flag indicating that the write may be buffered and need not go out - /// on the wire immediately. - /// - /// \sa GRPC_WRITE_BUFFER_HINT - inline WriteOptions& clear_buffer_hint() { - ClearBit(GRPC_WRITE_BUFFER_HINT); - return *this; - } - - /// Get value for the flag indicating that the write may be buffered and need - /// not go out on the wire immediately. - /// - /// \sa GRPC_WRITE_BUFFER_HINT - inline bool get_buffer_hint() const { return GetBit(GRPC_WRITE_BUFFER_HINT); } - - /// corked bit: aliases set_buffer_hint currently, with the intent that - /// set_buffer_hint will be removed in the future - inline WriteOptions& set_corked() { - SetBit(GRPC_WRITE_BUFFER_HINT); - return *this; - } - - inline WriteOptions& clear_corked() { - ClearBit(GRPC_WRITE_BUFFER_HINT); - return *this; - } - - inline bool is_corked() const { return GetBit(GRPC_WRITE_BUFFER_HINT); } - - /// last-message bit: indicates this is the last message in a stream - /// client-side: makes Write the equivalent of performing Write, WritesDone - /// in a single step - /// server-side: hold the Write until the service handler returns (sync api) - /// or until Finish is called (async api) - inline WriteOptions& set_last_message() { - last_message_ = true; - return *this; - } - - /// Clears flag indicating that this is the last message in a stream, - /// disabling coalescing. - inline WriteOptions& clear_last_message() { - last_message_ = false; - return *this; - } - - /// Guarantee that all bytes have been written to the socket before completing - /// this write (usually writes are completed when they pass flow control). - inline WriteOptions& set_write_through() { - SetBit(GRPC_WRITE_THROUGH); - return *this; - } - - inline bool is_write_through() const { return GetBit(GRPC_WRITE_THROUGH); } - - /// Get value for the flag indicating that this is the last message, and - /// should be coalesced with trailing metadata. - /// - /// \sa GRPC_WRITE_LAST_MESSAGE - bool is_last_message() const { return last_message_; } - - WriteOptions& operator=(const WriteOptions& rhs) { - flags_ = rhs.flags_; - return *this; - } - - private: - void SetBit(const uint32_t mask) { flags_ |= mask; } - - void ClearBit(const uint32_t mask) { flags_ &= ~mask; } - - bool GetBit(const uint32_t mask) const { return (flags_ & mask) != 0; } - - uint32_t flags_; - bool last_message_; -}; - -namespace internal { -/// Default argument for CallOpSet. I is unused by the class, but can be -/// used for generating multiple names for the same thing. -template -class CallNoOp { - protected: - void AddOp(grpc_op* ops, size_t* nops) {} - void FinishOp(bool* status) {} -}; - -class CallOpSendInitialMetadata { - public: - CallOpSendInitialMetadata() : send_(false) { - maybe_compression_level_.is_set = false; - } - - void SendInitialMetadata( - const std::multimap& metadata, - uint32_t flags) { - maybe_compression_level_.is_set = false; - send_ = true; - flags_ = flags; - initial_metadata_ = - FillMetadataArray(metadata, &initial_metadata_count_, ""); - } - - void set_compression_level(grpc_compression_level level) { - maybe_compression_level_.is_set = true; - maybe_compression_level_.level = level; - } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!send_) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_SEND_INITIAL_METADATA; - op->flags = flags_; - op->reserved = NULL; - op->data.send_initial_metadata.count = initial_metadata_count_; - op->data.send_initial_metadata.metadata = initial_metadata_; - op->data.send_initial_metadata.maybe_compression_level.is_set = - maybe_compression_level_.is_set; - if (maybe_compression_level_.is_set) { - op->data.send_initial_metadata.maybe_compression_level.level = - maybe_compression_level_.level; - } - } - void FinishOp(bool* status) { - if (!send_) return; - g_core_codegen_interface->gpr_free(initial_metadata_); - send_ = false; - } - - bool send_; - uint32_t flags_; - size_t initial_metadata_count_; - grpc_metadata* initial_metadata_; - struct { - bool is_set; - grpc_compression_level level; - } maybe_compression_level_; -}; - -class CallOpSendMessage { - public: - CallOpSendMessage() : send_buf_() {} - - /// Send \a message using \a options for the write. The \a options are cleared - /// after use. - template - Status SendMessage(const M& message, - WriteOptions options) GRPC_MUST_USE_RESULT; - - template - Status SendMessage(const M& message) GRPC_MUST_USE_RESULT; - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!send_buf_.Valid()) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_SEND_MESSAGE; - op->flags = write_options_.flags(); - op->reserved = NULL; - op->data.send_message.send_message = send_buf_.c_buffer(); - // Flags are per-message: clear them after use. - write_options_.Clear(); - } - void FinishOp(bool* status) { send_buf_.Clear(); } - - private: - ByteBuffer send_buf_; - WriteOptions write_options_; -}; - -template -Status CallOpSendMessage::SendMessage(const M& message, WriteOptions options) { - write_options_ = options; - bool own_buf; - // TODO(vjpai): Remove the void below when possible - // The void in the template parameter below should not be needed - // (since it should be implicit) but is needed due to an observed - // difference in behavior between clang and gcc for certain internal users - Status result = SerializationTraits::Serialize( - message, send_buf_.bbuf_ptr(), &own_buf); - if (!own_buf) { - send_buf_.Duplicate(); - } - return result; -} - -template -Status CallOpSendMessage::SendMessage(const M& message) { - return SendMessage(message, WriteOptions()); -} - -template -class CallOpRecvMessage { - public: - CallOpRecvMessage() - : got_message(false), - message_(nullptr), - allow_not_getting_message_(false) {} - - void RecvMessage(R* message) { message_ = message; } - - // Do not change status if no message is received. - void AllowNoMessage() { allow_not_getting_message_ = true; } - - bool got_message; - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (message_ == nullptr) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_RECV_MESSAGE; - op->flags = 0; - op->reserved = NULL; - op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); - } - - void FinishOp(bool* status) { - if (message_ == nullptr) return; - if (recv_buf_.Valid()) { - if (*status) { - got_message = *status = - SerializationTraits::Deserialize(recv_buf_.bbuf_ptr(), message_) - .ok(); - recv_buf_.Release(); - } else { - got_message = false; - recv_buf_.Clear(); - } - } else { - got_message = false; - if (!allow_not_getting_message_) { - *status = false; - } - } - message_ = nullptr; - } - - private: - R* message_; - ByteBuffer recv_buf_; - bool allow_not_getting_message_; -}; - -class DeserializeFunc { - public: - virtual Status Deserialize(ByteBuffer* buf) = 0; - virtual ~DeserializeFunc() {} -}; - -template -class DeserializeFuncType final : public DeserializeFunc { - public: - DeserializeFuncType(R* message) : message_(message) {} - Status Deserialize(ByteBuffer* buf) override { - return SerializationTraits::Deserialize(buf->bbuf_ptr(), message_); - } - - ~DeserializeFuncType() override {} - - private: - R* message_; // Not a managed pointer because management is external to this -}; - -class CallOpGenericRecvMessage { - public: - CallOpGenericRecvMessage() - : got_message(false), allow_not_getting_message_(false) {} - - template - void RecvMessage(R* message) { - // Use an explicit base class pointer to avoid resolution error in the - // following unique_ptr::reset for some old implementations. - DeserializeFunc* func = new DeserializeFuncType(message); - deserialize_.reset(func); - } - - // Do not change status if no message is received. - void AllowNoMessage() { allow_not_getting_message_ = true; } - - bool got_message; - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!deserialize_) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_RECV_MESSAGE; - op->flags = 0; - op->reserved = NULL; - op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); - } - - void FinishOp(bool* status) { - if (!deserialize_) return; - if (recv_buf_.Valid()) { - if (*status) { - got_message = true; - *status = deserialize_->Deserialize(&recv_buf_).ok(); - recv_buf_.Release(); - } else { - got_message = false; - recv_buf_.Clear(); - } - } else { - got_message = false; - if (!allow_not_getting_message_) { - *status = false; - } - } - deserialize_.reset(); - } - - private: - std::unique_ptr deserialize_; - ByteBuffer recv_buf_; - bool allow_not_getting_message_; -}; - -class CallOpClientSendClose { - public: - CallOpClientSendClose() : send_(false) {} - - void ClientSendClose() { send_ = true; } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!send_) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; - op->flags = 0; - op->reserved = NULL; - } - void FinishOp(bool* status) { send_ = false; } - - private: - bool send_; -}; - -class CallOpServerSendStatus { - public: - CallOpServerSendStatus() : send_status_available_(false) {} - - void ServerSendStatus( - const std::multimap& trailing_metadata, - const Status& status) { - send_error_details_ = status.error_details(); - trailing_metadata_ = FillMetadataArray( - trailing_metadata, &trailing_metadata_count_, send_error_details_); - send_status_available_ = true; - send_status_code_ = static_cast(status.error_code()); - send_error_message_ = status.error_message(); - } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (!send_status_available_) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; - op->data.send_status_from_server.trailing_metadata_count = - trailing_metadata_count_; - op->data.send_status_from_server.trailing_metadata = trailing_metadata_; - op->data.send_status_from_server.status = send_status_code_; - error_message_slice_ = SliceReferencingString(send_error_message_); - op->data.send_status_from_server.status_details = - send_error_message_.empty() ? nullptr : &error_message_slice_; - op->flags = 0; - op->reserved = NULL; - } - - void FinishOp(bool* status) { - if (!send_status_available_) return; - g_core_codegen_interface->gpr_free(trailing_metadata_); - send_status_available_ = false; - } - - private: - bool send_status_available_; - grpc_status_code send_status_code_; - grpc::string send_error_details_; - grpc::string send_error_message_; - size_t trailing_metadata_count_; - grpc_metadata* trailing_metadata_; - grpc_slice error_message_slice_; -}; - -class CallOpRecvInitialMetadata { - public: - CallOpRecvInitialMetadata() : metadata_map_(nullptr) {} - - void RecvInitialMetadata(ClientContext* context) { - context->initial_metadata_received_ = true; - metadata_map_ = &context->recv_initial_metadata_; - } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (metadata_map_ == nullptr) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_RECV_INITIAL_METADATA; - op->data.recv_initial_metadata.recv_initial_metadata = metadata_map_->arr(); - op->flags = 0; - op->reserved = NULL; - } - - void FinishOp(bool* status) { - if (metadata_map_ == nullptr) return; - metadata_map_ = nullptr; - } - - private: - MetadataMap* metadata_map_; -}; - -class CallOpClientRecvStatus { - public: - CallOpClientRecvStatus() - : recv_status_(nullptr), debug_error_string_(nullptr) {} - - void ClientRecvStatus(ClientContext* context, Status* status) { - client_context_ = context; - metadata_map_ = &client_context_->trailing_metadata_; - recv_status_ = status; - error_message_ = g_core_codegen_interface->grpc_empty_slice(); - } - - protected: - void AddOp(grpc_op* ops, size_t* nops) { - if (recv_status_ == nullptr) return; - grpc_op* op = &ops[(*nops)++]; - op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; - op->data.recv_status_on_client.trailing_metadata = metadata_map_->arr(); - op->data.recv_status_on_client.status = &status_code_; - op->data.recv_status_on_client.status_details = &error_message_; - op->data.recv_status_on_client.error_string = &debug_error_string_; - op->flags = 0; - op->reserved = NULL; - } - - void FinishOp(bool* status) { - if (recv_status_ == nullptr) return; - grpc::string binary_error_details = metadata_map_->GetBinaryErrorDetails(); - *recv_status_ = - Status(static_cast(status_code_), - GRPC_SLICE_IS_EMPTY(error_message_) - ? grpc::string() - : grpc::string(GRPC_SLICE_START_PTR(error_message_), - GRPC_SLICE_END_PTR(error_message_)), - binary_error_details); - client_context_->set_debug_error_string( - debug_error_string_ != nullptr ? debug_error_string_ : ""); - g_core_codegen_interface->grpc_slice_unref(error_message_); - if (debug_error_string_ != nullptr) { - g_core_codegen_interface->gpr_free((void*)debug_error_string_); - } - recv_status_ = nullptr; - } - - private: - ClientContext* client_context_; - MetadataMap* metadata_map_; - Status* recv_status_; - const char* debug_error_string_; - grpc_status_code status_code_; - grpc_slice error_message_; -}; - -/// An abstract collection of call ops, used to generate the -/// grpc_call_op structure to pass down to the lower layers, -/// and as it is-a CompletionQueueTag, also massages the final -/// completion into the correct form for consumption in the C++ -/// API. -class CallOpSetInterface : public CompletionQueueTag { - public: - /// Fills in grpc_op, starting from ops[*nops] and moving - /// upwards. - virtual void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) = 0; - - /// Get the tag to be used at the core completion queue. Generally, the - /// value of cq_tag will be "this". However, it can be overridden if we - /// want core to process the tag differently (e.g., as a core callback) - virtual void* cq_tag() = 0; -}; - -/// Primary implementation of CallOpSetInterface. -/// Since we cannot use variadic templates, we declare slots up to -/// the maximum count of ops we'll need in a set. We leverage the -/// empty base class optimization to slim this class (especially -/// when there are many unused slots used). To avoid duplicate base classes, -/// the template parmeter for CallNoOp is varied by argument position. -template , class Op2 = CallNoOp<2>, - class Op3 = CallNoOp<3>, class Op4 = CallNoOp<4>, - class Op5 = CallNoOp<5>, class Op6 = CallNoOp<6>> -class CallOpSet : public CallOpSetInterface, - public Op1, - public Op2, - public Op3, - public Op4, - public Op5, - public Op6 { - public: - CallOpSet() : cq_tag_(this), return_tag_(this), call_(nullptr) {} - - // The copy constructor and assignment operator reset the value of - // cq_tag_ and return_tag_ since those are only meaningful on a specific - // object, not across objects. - CallOpSet(const CallOpSet& other) - : cq_tag_(this), return_tag_(this), call_(other.call_) {} - CallOpSet& operator=(const CallOpSet& other) { - cq_tag_ = this; - return_tag_ = this; - call_ = other.call_; - return *this; - } - - void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) override { - this->Op1::AddOp(ops, nops); - this->Op2::AddOp(ops, nops); - this->Op3::AddOp(ops, nops); - this->Op4::AddOp(ops, nops); - this->Op5::AddOp(ops, nops); - this->Op6::AddOp(ops, nops); - g_core_codegen_interface->grpc_call_ref(call); - call_ = call; - } - - bool FinalizeResult(void** tag, bool* status) override { - this->Op1::FinishOp(status); - this->Op2::FinishOp(status); - this->Op3::FinishOp(status); - this->Op4::FinishOp(status); - this->Op5::FinishOp(status); - this->Op6::FinishOp(status); - *tag = return_tag_; - - g_core_codegen_interface->grpc_call_unref(call_); - return true; - } - - void set_output_tag(void* return_tag) { return_tag_ = return_tag; } - - void* cq_tag() override { return cq_tag_; } - - /// set_cq_tag is used to provide a different core CQ tag than "this". - /// This is used for callback-based tags, where the core tag is the core - /// callback function. It does not change the use or behavior of any other - /// function (such as FinalizeResult) - void set_cq_tag(void* cq_tag) { cq_tag_ = cq_tag; } - - private: - void* cq_tag_; - void* return_tag_; - grpc_call* call_; -}; +class CallOpSetInterface; /// Straightforward wrapping of the C call object class Call final { public: + Call() + : call_hook_(nullptr), + cq_(nullptr), + call_(nullptr), + max_receive_message_size_(-1) {} /** call is owned by the caller */ Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) : call_hook_(call_hook), @@ -688,11 +48,20 @@ class Call final { max_receive_message_size_(-1) {} Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, - int max_receive_message_size) + experimental::ClientRpcInfo* rpc_info) : call_hook_(call_hook), cq_(cq), call_(call), - max_receive_message_size_(max_receive_message_size) {} + max_receive_message_size_(-1), + client_rpc_info_(rpc_info) {} + + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, + int max_receive_message_size, experimental::ServerRpcInfo* rpc_info) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(max_receive_message_size), + server_rpc_info_(rpc_info) {} void PerformOps(CallOpSetInterface* ops) { call_hook_->PerformOpsOnCall(ops, this); @@ -703,11 +72,21 @@ class Call final { int max_receive_message_size() const { return max_receive_message_size_; } + experimental::ClientRpcInfo* client_rpc_info() const { + return client_rpc_info_; + } + + experimental::ServerRpcInfo* server_rpc_info() const { + return server_rpc_info_; + } + private: CallHook* call_hook_; CompletionQueue* cq_; grpc_call* call_; int max_receive_message_size_; + experimental::ClientRpcInfo* client_rpc_info_ = nullptr; + experimental::ServerRpcInfo* server_rpc_info_ = nullptr; }; } // namespace internal } // namespace grpc diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h new file mode 100644 index 00000000000..785688e67f3 --- /dev/null +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -0,0 +1,920 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_CALL_OP_SET_H +#define GRPCPP_IMPL_CODEGEN_CALL_OP_SET_H + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace grpc { + +class CompletionQueue; +extern CoreCodegenInterface* g_core_codegen_interface; + +namespace internal { +class Call; +class CallHook; + +// TODO(yangg) if the map is changed before we send, the pointers will be a +// mess. Make sure it does not happen. +inline grpc_metadata* FillMetadataArray( + const std::multimap& metadata, + size_t* metadata_count, const grpc::string& optional_error_details) { + *metadata_count = metadata.size() + (optional_error_details.empty() ? 0 : 1); + if (*metadata_count == 0) { + return nullptr; + } + grpc_metadata* metadata_array = + (grpc_metadata*)(g_core_codegen_interface->gpr_malloc( + (*metadata_count) * sizeof(grpc_metadata))); + size_t i = 0; + for (auto iter = metadata.cbegin(); iter != metadata.cend(); ++iter, ++i) { + metadata_array[i].key = SliceReferencingString(iter->first); + metadata_array[i].value = SliceReferencingString(iter->second); + } + if (!optional_error_details.empty()) { + metadata_array[i].key = + g_core_codegen_interface->grpc_slice_from_static_buffer( + kBinaryErrorDetailsKey, sizeof(kBinaryErrorDetailsKey) - 1); + metadata_array[i].value = SliceReferencingString(optional_error_details); + } + return metadata_array; +} +} // namespace internal + +/// Per-message write options. +class WriteOptions { + public: + WriteOptions() : flags_(0), last_message_(false) {} + WriteOptions(const WriteOptions& other) + : flags_(other.flags_), last_message_(other.last_message_) {} + + /// Clear all flags. + inline void Clear() { flags_ = 0; } + + /// Returns raw flags bitset. + inline uint32_t flags() const { return flags_; } + + /// Sets flag for the disabling of compression for the next message write. + /// + /// \sa GRPC_WRITE_NO_COMPRESS + inline WriteOptions& set_no_compression() { + SetBit(GRPC_WRITE_NO_COMPRESS); + return *this; + } + + /// Clears flag for the disabling of compression for the next message write. + /// + /// \sa GRPC_WRITE_NO_COMPRESS + inline WriteOptions& clear_no_compression() { + ClearBit(GRPC_WRITE_NO_COMPRESS); + return *this; + } + + /// Get value for the flag indicating whether compression for the next + /// message write is forcefully disabled. + /// + /// \sa GRPC_WRITE_NO_COMPRESS + inline bool get_no_compression() const { + return GetBit(GRPC_WRITE_NO_COMPRESS); + } + + /// Sets flag indicating that the write may be buffered and need not go out on + /// the wire immediately. + /// + /// \sa GRPC_WRITE_BUFFER_HINT + inline WriteOptions& set_buffer_hint() { + SetBit(GRPC_WRITE_BUFFER_HINT); + return *this; + } + + /// Clears flag indicating that the write may be buffered and need not go out + /// on the wire immediately. + /// + /// \sa GRPC_WRITE_BUFFER_HINT + inline WriteOptions& clear_buffer_hint() { + ClearBit(GRPC_WRITE_BUFFER_HINT); + return *this; + } + + /// Get value for the flag indicating that the write may be buffered and need + /// not go out on the wire immediately. + /// + /// \sa GRPC_WRITE_BUFFER_HINT + inline bool get_buffer_hint() const { return GetBit(GRPC_WRITE_BUFFER_HINT); } + + /// corked bit: aliases set_buffer_hint currently, with the intent that + /// set_buffer_hint will be removed in the future + inline WriteOptions& set_corked() { + SetBit(GRPC_WRITE_BUFFER_HINT); + return *this; + } + + inline WriteOptions& clear_corked() { + ClearBit(GRPC_WRITE_BUFFER_HINT); + return *this; + } + + inline bool is_corked() const { return GetBit(GRPC_WRITE_BUFFER_HINT); } + + /// last-message bit: indicates this is the last message in a stream + /// client-side: makes Write the equivalent of performing Write, WritesDone + /// in a single step + /// server-side: hold the Write until the service handler returns (sync api) + /// or until Finish is called (async api) + inline WriteOptions& set_last_message() { + last_message_ = true; + return *this; + } + + /// Clears flag indicating that this is the last message in a stream, + /// disabling coalescing. + inline WriteOptions& clear_last_message() { + last_message_ = false; + return *this; + } + + /// Guarantee that all bytes have been written to the socket before completing + /// this write (usually writes are completed when they pass flow control). + inline WriteOptions& set_write_through() { + SetBit(GRPC_WRITE_THROUGH); + return *this; + } + + inline bool is_write_through() const { return GetBit(GRPC_WRITE_THROUGH); } + + /// Get value for the flag indicating that this is the last message, and + /// should be coalesced with trailing metadata. + /// + /// \sa GRPC_WRITE_LAST_MESSAGE + bool is_last_message() const { return last_message_; } + + WriteOptions& operator=(const WriteOptions& rhs) { + flags_ = rhs.flags_; + return *this; + } + + private: + void SetBit(const uint32_t mask) { flags_ |= mask; } + + void ClearBit(const uint32_t mask) { flags_ &= ~mask; } + + bool GetBit(const uint32_t mask) const { return (flags_ & mask) != 0; } + + uint32_t flags_; + bool last_message_; +}; + +namespace internal { + +/// Default argument for CallOpSet. I is unused by the class, but can be +/// used for generating multiple names for the same thing. +template +class CallNoOp { + protected: + void AddOp(grpc_op* ops, size_t* nops) {} + void FinishOp(bool* status) {} + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + } +}; + +class CallOpSendInitialMetadata { + public: + CallOpSendInitialMetadata() : send_(false) { + maybe_compression_level_.is_set = false; + } + + void SendInitialMetadata(std::multimap* metadata, + uint32_t flags) { + maybe_compression_level_.is_set = false; + send_ = true; + flags_ = flags; + metadata_map_ = metadata; + } + + void set_compression_level(grpc_compression_level level) { + maybe_compression_level_.is_set = true; + maybe_compression_level_.level = level; + } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_ || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_SEND_INITIAL_METADATA; + op->flags = flags_; + op->reserved = NULL; + initial_metadata_ = + FillMetadataArray(*metadata_map_, &initial_metadata_count_, ""); + op->data.send_initial_metadata.count = initial_metadata_count_; + op->data.send_initial_metadata.metadata = initial_metadata_; + op->data.send_initial_metadata.maybe_compression_level.is_set = + maybe_compression_level_.is_set; + if (maybe_compression_level_.is_set) { + op->data.send_initial_metadata.maybe_compression_level.level = + maybe_compression_level_.level; + } + } + void FinishOp(bool* status) { + if (!send_ || hijacked_) return; + g_core_codegen_interface->gpr_free(initial_metadata_); + send_ = false; + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA); + interceptor_methods->SetSendInitialMetadata(metadata_map_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + bool hijacked_ = false; + bool send_; + uint32_t flags_; + size_t initial_metadata_count_; + std::multimap* metadata_map_; + grpc_metadata* initial_metadata_; + struct { + bool is_set; + grpc_compression_level level; + } maybe_compression_level_; +}; + +class CallOpSendMessage { + public: + CallOpSendMessage() : send_buf_() {} + + /// Send \a message using \a options for the write. The \a options are cleared + /// after use. + template + Status SendMessage(const M& message, + WriteOptions options) GRPC_MUST_USE_RESULT; + + template + Status SendMessage(const M& message) GRPC_MUST_USE_RESULT; + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_buf_.Valid() || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_SEND_MESSAGE; + op->flags = write_options_.flags(); + op->reserved = NULL; + op->data.send_message.send_message = send_buf_.c_buffer(); + // Flags are per-message: clear them after use. + write_options_.Clear(); + } + void FinishOp(bool* status) { send_buf_.Clear(); } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_buf_.Valid()) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); + interceptor_methods->SetSendMessage(&send_buf_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + private: + bool hijacked_ = false; + ByteBuffer send_buf_; + WriteOptions write_options_; +}; + +template +Status CallOpSendMessage::SendMessage(const M& message, WriteOptions options) { + write_options_ = options; + bool own_buf; + // TODO(vjpai): Remove the void below when possible + // The void in the template parameter below should not be needed + // (since it should be implicit) but is needed due to an observed + // difference in behavior between clang and gcc for certain internal users + Status result = SerializationTraits::Serialize( + message, send_buf_.bbuf_ptr(), &own_buf); + if (!own_buf) { + send_buf_.Duplicate(); + } + return result; +} + +template +Status CallOpSendMessage::SendMessage(const M& message) { + return SendMessage(message, WriteOptions()); +} + +template +class CallOpRecvMessage { + public: + CallOpRecvMessage() + : got_message(false), + message_(nullptr), + allow_not_getting_message_(false) {} + + void RecvMessage(R* message) { message_ = message; } + + // Do not change status if no message is received. + void AllowNoMessage() { allow_not_getting_message_ = true; } + + bool got_message; + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (message_ == nullptr || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_RECV_MESSAGE; + op->flags = 0; + op->reserved = NULL; + op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); + } + + void FinishOp(bool* status) { + if (message_ == nullptr || hijacked_) return; + if (recv_buf_.Valid()) { + if (*status) { + got_message = *status = + SerializationTraits::Deserialize(recv_buf_.bbuf_ptr(), message_) + .ok(); + recv_buf_.Release(); + } else { + got_message = false; + recv_buf_.Clear(); + } + } else { + got_message = false; + if (!allow_not_getting_message_) { + *status = false; + } + } + message_ = nullptr; + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvMessage(message_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!got_message) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + } + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (message_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); + got_message = true; + } + + private: + R* message_; + ByteBuffer recv_buf_; + bool allow_not_getting_message_; + bool hijacked_ = false; +}; + +class DeserializeFunc { + public: + virtual Status Deserialize(ByteBuffer* buf) = 0; + virtual ~DeserializeFunc() {} +}; + +template +class DeserializeFuncType final : public DeserializeFunc { + public: + DeserializeFuncType(R* message) : message_(message) {} + Status Deserialize(ByteBuffer* buf) override { + return SerializationTraits::Deserialize(buf->bbuf_ptr(), message_); + } + + ~DeserializeFuncType() override {} + + private: + R* message_; // Not a managed pointer because management is external to this +}; + +class CallOpGenericRecvMessage { + public: + CallOpGenericRecvMessage() + : got_message(false), allow_not_getting_message_(false) {} + + template + void RecvMessage(R* message) { + // Use an explicit base class pointer to avoid resolution error in the + // following unique_ptr::reset for some old implementations. + DeserializeFunc* func = new DeserializeFuncType(message); + deserialize_.reset(func); + message_ = message; + } + + // Do not change status if no message is received. + void AllowNoMessage() { allow_not_getting_message_ = true; } + + bool got_message; + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!deserialize_ || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_RECV_MESSAGE; + op->flags = 0; + op->reserved = NULL; + op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); + } + + void FinishOp(bool* status) { + if (!deserialize_ || hijacked_) return; + if (recv_buf_.Valid()) { + if (*status) { + got_message = true; + *status = deserialize_->Deserialize(&recv_buf_).ok(); + recv_buf_.Release(); + } else { + got_message = false; + recv_buf_.Clear(); + } + } else { + got_message = false; + if (!allow_not_getting_message_) { + *status = false; + } + } + deserialize_.reset(); + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvMessage(message_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!got_message) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + } + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (!deserialize_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); + } + + private: + void* message_; + bool hijacked_ = false; + std::unique_ptr deserialize_; + ByteBuffer recv_buf_; + bool allow_not_getting_message_; +}; + +class CallOpClientSendClose { + public: + CallOpClientSendClose() : send_(false) {} + + void ClientSendClose() { send_ = true; } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_ || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; + op->flags = 0; + op->reserved = NULL; + } + void FinishOp(bool* status) { send_ = false; } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + private: + bool hijacked_ = false; + bool send_; +}; + +class CallOpServerSendStatus { + public: + CallOpServerSendStatus() : send_status_available_(false) {} + + void ServerSendStatus( + std::multimap* trailing_metadata, + const Status& status) { + send_error_details_ = status.error_details(); + metadata_map_ = trailing_metadata; + send_status_available_ = true; + send_status_code_ = static_cast(status.error_code()); + send_error_message_ = status.error_message(); + } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (!send_status_available_ || hijacked_) return; + trailing_metadata_ = FillMetadataArray( + *metadata_map_, &trailing_metadata_count_, send_error_details_); + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; + op->data.send_status_from_server.trailing_metadata_count = + trailing_metadata_count_; + op->data.send_status_from_server.trailing_metadata = trailing_metadata_; + op->data.send_status_from_server.status = send_status_code_; + error_message_slice_ = SliceReferencingString(send_error_message_); + op->data.send_status_from_server.status_details = + send_error_message_.empty() ? nullptr : &error_message_slice_; + op->flags = 0; + op->reserved = NULL; + } + + void FinishOp(bool* status) { + if (!send_status_available_ || hijacked_) return; + g_core_codegen_interface->gpr_free(trailing_metadata_); + send_status_available_ = false; + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (!send_status_available_) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_STATUS); + interceptor_methods->SetSendTrailingMetadata(metadata_map_); + interceptor_methods->SetSendStatus(&send_status_code_, &send_error_details_, + &send_error_message_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) {} + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + } + + private: + bool hijacked_ = false; + bool send_status_available_; + grpc_status_code send_status_code_; + grpc::string send_error_details_; + grpc::string send_error_message_; + size_t trailing_metadata_count_; + std::multimap* metadata_map_; + grpc_metadata* trailing_metadata_; + grpc_slice error_message_slice_; +}; + +class CallOpRecvInitialMetadata { + public: + CallOpRecvInitialMetadata() : metadata_map_(nullptr) {} + + void RecvInitialMetadata(ClientContext* context) { + context->initial_metadata_received_ = true; + metadata_map_ = &context->recv_initial_metadata_; + } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (metadata_map_ == nullptr || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_RECV_INITIAL_METADATA; + op->data.recv_initial_metadata.recv_initial_metadata = metadata_map_->arr(); + op->flags = 0; + op->reserved = NULL; + } + + void FinishOp(bool* status) { + if (metadata_map_ == nullptr || hijacked_) return; + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvInitialMetadata(metadata_map_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (metadata_map_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + metadata_map_ = nullptr; + } + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (metadata_map_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA); + } + + private: + bool hijacked_ = false; + MetadataMap* metadata_map_; +}; + +class CallOpClientRecvStatus { + public: + CallOpClientRecvStatus() + : recv_status_(nullptr), debug_error_string_(nullptr) {} + + void ClientRecvStatus(ClientContext* context, Status* status) { + client_context_ = context; + metadata_map_ = &client_context_->trailing_metadata_; + recv_status_ = status; + error_message_ = g_core_codegen_interface->grpc_empty_slice(); + } + + protected: + void AddOp(grpc_op* ops, size_t* nops) { + if (recv_status_ == nullptr || hijacked_) return; + grpc_op* op = &ops[(*nops)++]; + op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; + op->data.recv_status_on_client.trailing_metadata = metadata_map_->arr(); + op->data.recv_status_on_client.status = &status_code_; + op->data.recv_status_on_client.status_details = &error_message_; + op->data.recv_status_on_client.error_string = &debug_error_string_; + op->flags = 0; + op->reserved = NULL; + } + + void FinishOp(bool* status) { + if (recv_status_ == nullptr || hijacked_) return; + grpc::string binary_error_details = metadata_map_->GetBinaryErrorDetails(); + *recv_status_ = + Status(static_cast(status_code_), + GRPC_SLICE_IS_EMPTY(error_message_) + ? grpc::string() + : grpc::string(GRPC_SLICE_START_PTR(error_message_), + GRPC_SLICE_END_PTR(error_message_)), + binary_error_details); + client_context_->set_debug_error_string( + debug_error_string_ != nullptr ? debug_error_string_ : ""); + g_core_codegen_interface->grpc_slice_unref(error_message_); + if (debug_error_string_ != nullptr) { + g_core_codegen_interface->gpr_free((void*)debug_error_string_); + } + } + + void SetInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + interceptor_methods->SetRecvStatus(recv_status_); + interceptor_methods->SetRecvTrailingMetadata(metadata_map_); + } + + void SetFinishInterceptionHookPoint( + InternalInterceptorBatchMethods* interceptor_methods) { + if (recv_status_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS); + recv_status_ = nullptr; + } + + void SetHijackingState(InternalInterceptorBatchMethods* interceptor_methods) { + hijacked_ = true; + if (recv_status_ == nullptr) return; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS); + } + + private: + bool hijacked_ = false; + ClientContext* client_context_; + MetadataMap* metadata_map_; + Status* recv_status_; + const char* debug_error_string_; + grpc_status_code status_code_; + grpc_slice error_message_; +}; + +template , class Op2 = CallNoOp<2>, + class Op3 = CallNoOp<3>, class Op4 = CallNoOp<4>, + class Op5 = CallNoOp<5>, class Op6 = CallNoOp<6>> +class CallOpSet; + +/// Primary implementation of CallOpSetInterface. +/// Since we cannot use variadic templates, we declare slots up to +/// the maximum count of ops we'll need in a set. We leverage the +/// empty base class optimization to slim this class (especially +/// when there are many unused slots used). To avoid duplicate base classes, +/// the template parmeter for CallNoOp is varied by argument position. +template +class CallOpSet : public CallOpSetInterface, + public Op1, + public Op2, + public Op3, + public Op4, + public Op5, + public Op6 { + public: + CallOpSet() : cq_tag_(this), return_tag_(this) {} + // The copy constructor and assignment operator reset the value of + // cq_tag_, return_tag_, done_intercepting_ and interceptor_methods_ since + // those are only meaningful on a specific object, not across objects. + CallOpSet(const CallOpSet& other) + : cq_tag_(this), + return_tag_(this), + call_(other.call_), + done_intercepting_(false), + interceptor_methods_(InterceptorBatchMethodsImpl()) {} + + CallOpSet& operator=(const CallOpSet& other) { + cq_tag_ = this; + return_tag_ = this; + call_ = other.call_; + done_intercepting_ = false; + interceptor_methods_ = InterceptorBatchMethodsImpl(); + return *this; + } + + void FillOps(Call* call) override { + done_intercepting_ = false; + g_core_codegen_interface->grpc_call_ref(call->call()); + call_ = + *call; // It's fine to create a copy of call since it's just pointers + + if (RunInterceptors()) { + ContinueFillOpsAfterInterception(); + } else { + // After the interceptors are run, ContinueFillOpsAfterInterception will + // be run + } + } + + bool FinalizeResult(void** tag, bool* status) override { + if (done_intercepting_) { + // We have already finished intercepting and filling in the results. This + // round trip from the core needed to be made because interceptors were + // run + *tag = return_tag_; + *status = saved_status_; + g_core_codegen_interface->grpc_call_unref(call_.call()); + return true; + } + + this->Op1::FinishOp(status); + this->Op2::FinishOp(status); + this->Op3::FinishOp(status); + this->Op4::FinishOp(status); + this->Op5::FinishOp(status); + this->Op6::FinishOp(status); + saved_status_ = *status; + if (RunInterceptorsPostRecv()) { + *tag = return_tag_; + g_core_codegen_interface->grpc_call_unref(call_.call()); + return true; + } + // Interceptors are going to be run, so we can't return the tag just yet. + // After the interceptors are run, ContinueFinalizeResultAfterInterception + return false; + } + + void set_output_tag(void* return_tag) { return_tag_ = return_tag; } + + void* cq_tag() override { return cq_tag_; } + + /// set_cq_tag is used to provide a different core CQ tag than "this". + /// This is used for callback-based tags, where the core tag is the core + /// callback function. It does not change the use or behavior of any other + /// function (such as FinalizeResult) + void set_cq_tag(void* cq_tag) { cq_tag_ = cq_tag; } + + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override { + this->Op1::SetHijackingState(&interceptor_methods_); + this->Op2::SetHijackingState(&interceptor_methods_); + this->Op3::SetHijackingState(&interceptor_methods_); + this->Op4::SetHijackingState(&interceptor_methods_); + this->Op5::SetHijackingState(&interceptor_methods_); + this->Op6::SetHijackingState(&interceptor_methods_); + } + + // Should be called after interceptors are done running + void ContinueFillOpsAfterInterception() override { + static const size_t MAX_OPS = 6; + grpc_op ops[MAX_OPS]; + size_t nops = 0; + this->Op1::AddOp(ops, &nops); + this->Op2::AddOp(ops, &nops); + this->Op3::AddOp(ops, &nops); + this->Op4::AddOp(ops, &nops); + this->Op5::AddOp(ops, &nops); + this->Op6::AddOp(ops, &nops); + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), ops, nops, cq_tag(), nullptr)); + } + + // Should be called after interceptors are done running on the finalize result + // path + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), nullptr, 0, cq_tag(), nullptr)); + } + + private: + // Returns true if no interceptors need to be run + bool RunInterceptors() { + interceptor_methods_.ClearState(); + interceptor_methods_.SetCallOpSetInterface(this); + interceptor_methods_.SetCall(&call_); + this->Op1::SetInterceptionHookPoint(&interceptor_methods_); + this->Op2::SetInterceptionHookPoint(&interceptor_methods_); + this->Op3::SetInterceptionHookPoint(&interceptor_methods_); + this->Op4::SetInterceptionHookPoint(&interceptor_methods_); + this->Op5::SetInterceptionHookPoint(&interceptor_methods_); + this->Op6::SetInterceptionHookPoint(&interceptor_methods_); + return interceptor_methods_.RunInterceptors(); + } + // Returns true if no interceptors need to be run + bool RunInterceptorsPostRecv() { + // Call and OpSet had already been set on the set state. + // SetReverse also clears previously set hook points + interceptor_methods_.SetReverse(); + this->Op1::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op2::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op3::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op4::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op5::SetFinishInterceptionHookPoint(&interceptor_methods_); + this->Op6::SetFinishInterceptionHookPoint(&interceptor_methods_); + return interceptor_methods_.RunInterceptors(); + } + + void* cq_tag_; + void* return_tag_; + Call call_; + bool done_intercepting_ = false; + InterceptorBatchMethodsImpl interceptor_methods_; + bool saved_status_; +}; + +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_CALL_OP_SET_H diff --git a/include/grpcpp/impl/codegen/call_op_set_interface.h b/include/grpcpp/impl/codegen/call_op_set_interface.h new file mode 100644 index 00000000000..815227a2996 --- /dev/null +++ b/include/grpcpp/impl/codegen/call_op_set_interface.h @@ -0,0 +1,59 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_CALL_OP_SET_INTERFACE_H +#define GRPCPP_IMPL_CODEGEN_CALL_OP_SET_INTERFACE_H + +#include + +namespace grpc { +namespace internal { + +class Call; + +/// An abstract collection of call ops, used to generate the +/// grpc_call_op structure to pass down to the lower layers, +/// and as it is-a CompletionQueueTag, also massages the final +/// completion into the correct form for consumption in the C++ +/// API. +class CallOpSetInterface : public CompletionQueueTag { + public: + /// Fills in grpc_op, starting from ops[*nops] and moving + /// upwards. + virtual void FillOps(internal::Call* call) = 0; + + /// Get the tag to be used at the core completion queue. Generally, the + /// value of cq_tag will be "this". However, it can be overridden if we + /// want core to process the tag differently (e.g., as a core callback) + virtual void* cq_tag() = 0; + + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + virtual void SetHijackingState() = 0; + + // Should be called after interceptors are done running + virtual void ContinueFillOpsAfterInterception() = 0; + + // Should be called after interceptors are done running on the finalize result + // path + virtual void ContinueFinalizeResultAfterInterception() = 0; +}; +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_CALL_OP_SET_INTERFACE_H diff --git a/include/grpcpp/impl/codegen/callback_common.h b/include/grpcpp/impl/codegen/callback_common.h index a9835973acc..eba9ec6edc4 100644 --- a/include/grpcpp/impl/codegen/callback_common.h +++ b/include/grpcpp/impl/codegen/callback_common.h @@ -94,7 +94,10 @@ class CallbackWithStatusTag void Run(bool ok) { void* ignored = ops_; - GPR_CODEGEN_ASSERT(ops_->FinalizeResult(&ignored, &ok)); + if (!ops_->FinalizeResult(&ignored, &ok)) { + // The tag was swallowed + return; + } GPR_CODEGEN_ASSERT(ignored == ops_); // Last use of func_ or status_, so ok to move them out diff --git a/include/grpcpp/impl/codegen/channel_interface.h b/include/grpcpp/impl/codegen/channel_interface.h index b257acc1abd..6fd1dd1d9be 100644 --- a/include/grpcpp/impl/codegen/channel_interface.h +++ b/include/grpcpp/impl/codegen/channel_interface.h @@ -20,6 +20,7 @@ #define GRPCPP_IMPL_CODEGEN_CHANNEL_INTERFACE_H #include +#include #include #include @@ -51,6 +52,7 @@ template class ClientAsyncReaderWriterFactory; template class ClientAsyncResponseReaderFactory; +class InterceptedChannel; } // namespace internal /// Codegen interface for \a grpc::Channel. @@ -108,6 +110,7 @@ class ChannelInterface { template friend class ::grpc::internal::CallbackUnaryCallImpl; friend class ::grpc::internal::RpcMethod; + friend class ::grpc::internal::InterceptedChannel; virtual internal::Call CreateCall(const internal::RpcMethod& method, ClientContext* context, CompletionQueue* cq) = 0; @@ -120,6 +123,20 @@ class ChannelInterface { virtual bool WaitForStateChangeImpl(grpc_connectivity_state last_observed, gpr_timespec deadline) = 0; + // EXPERIMENTAL + // This is needed to keep codegen_test_minimal happy. InterceptedChannel needs + // to make use of this but can't directly call Channel's implementation + // because of the test. + // Returns an empty Call object (rather than being pure) since this is a new + // method and adding a new pure method to an interface would be a breaking + // change (even though this is private and non-API) + virtual internal::Call CreateCallInternal(const internal::RpcMethod& method, + ClientContext* context, + CompletionQueue* cq, + int interceptor_pos) { + return internal::Call(); + } + // EXPERIMENTAL // A method to get the callbackable completion queue associated with this // channel. If the return value is nullptr, this channel doesn't support diff --git a/include/grpcpp/impl/codegen/client_callback.h b/include/grpcpp/impl/codegen/client_callback.h index 4d4faea0635..ecb00a07690 100644 --- a/include/grpcpp/impl/codegen/client_callback.h +++ b/include/grpcpp/impl/codegen/client_callback.h @@ -77,7 +77,7 @@ class CallbackUnaryCallImpl { tag->force_run(s); return; } - ops->SendInitialMetadata(context->send_initial_metadata_, + ops->SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); ops->RecvInitialMetadata(context); ops->RecvMessage(result); diff --git a/include/grpcpp/impl/codegen/client_context.h b/include/grpcpp/impl/codegen/client_context.h index 24f5c431ceb..f53b744dcf4 100644 --- a/include/grpcpp/impl/codegen/client_context.h +++ b/include/grpcpp/impl/codegen/client_context.h @@ -41,6 +41,7 @@ #include #include +#include #include #include #include @@ -402,6 +403,17 @@ class ClientContext { grpc_call* call() const { return call_; } void set_call(grpc_call* call, const std::shared_ptr& channel); + experimental::ClientRpcInfo* set_client_rpc_info( + const char* method, grpc::ChannelInterface* channel, + const std::vector< + std::unique_ptr>& + creators, + size_t interceptor_pos) { + rpc_info_ = experimental::ClientRpcInfo(this, method, channel); + rpc_info_.RegisterInterceptors(creators, interceptor_pos); + return &rpc_info_; + } + uint32_t initial_metadata_flags() const { return (idempotent_ ? GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST : 0) | (wait_for_ready_ ? GRPC_INITIAL_METADATA_WAIT_FOR_READY : 0) | @@ -439,6 +451,8 @@ class ClientContext { bool initial_metadata_corked_; grpc::string debug_error_string_; + + experimental::ClientRpcInfo rpc_info_; }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h index f460c5ac0c0..00113f04aaa 100644 --- a/include/grpcpp/impl/codegen/client_interceptor.h +++ b/include/grpcpp/impl/codegen/client_interceptor.h @@ -19,23 +19,76 @@ #ifndef GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H #define GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H +#include + +#include #include +#include namespace grpc { -namespace experimental { -class ClientInterceptor { - public: - virtual ~ClientInterceptor() {} - virtual void Intercept(InterceptorBatchMethods* methods) = 0; -}; +class ClientContext; +class Channel; -class ClientRpcInfo {}; +namespace internal { +class InterceptorBatchMethodsImpl; +} + +namespace experimental { +class ClientRpcInfo; class ClientInterceptorFactoryInterface { public: virtual ~ClientInterceptorFactoryInterface() {} - virtual ClientInterceptor* CreateClientInterceptor(ClientRpcInfo* info) = 0; + virtual Interceptor* CreateClientInterceptor(ClientRpcInfo* info) = 0; +}; + +class ClientRpcInfo { + public: + ClientRpcInfo() {} + + ~ClientRpcInfo(){}; + + ClientRpcInfo(const ClientRpcInfo&) = delete; + ClientRpcInfo(ClientRpcInfo&&) = default; + ClientRpcInfo& operator=(ClientRpcInfo&&) = default; + + // Getter methods + const char* method() { return method_; } + ChannelInterface* channel() { return channel_; } + grpc::ClientContext* client_context() { return ctx_; } + + private: + ClientRpcInfo(grpc::ClientContext* ctx, const char* method, + grpc::ChannelInterface* channel) + : ctx_(ctx), method_(method), channel_(channel) {} + // Runs interceptor at pos \a pos. + void RunInterceptor( + experimental::InterceptorBatchMethods* interceptor_methods, size_t pos) { + GPR_CODEGEN_ASSERT(pos < interceptors_.size()); + interceptors_[pos]->Intercept(interceptor_methods); + } + + void RegisterInterceptors( + const std::vector>& creators, + int interceptor_pos) { + for (auto it = creators.begin() + interceptor_pos; it != creators.end(); + ++it) { + interceptors_.push_back(std::unique_ptr( + (*it)->CreateClientInterceptor(this))); + } + } + + grpc::ClientContext* ctx_ = nullptr; + const char* method_ = nullptr; + grpc::ChannelInterface* channel_ = nullptr; + std::vector> interceptors_; + bool hijacked_ = false; + size_t hijacked_interceptor_ = 0; + + friend class internal::InterceptorBatchMethodsImpl; + friend class grpc::ClientContext; }; } // namespace experimental diff --git a/include/grpcpp/impl/codegen/client_unary_call.h b/include/grpcpp/impl/codegen/client_unary_call.h index e4e8364e073..b1c80764f23 100644 --- a/include/grpcpp/impl/codegen/client_unary_call.h +++ b/include/grpcpp/impl/codegen/client_unary_call.h @@ -61,7 +61,7 @@ class BlockingUnaryCallImpl { if (!status_.ok()) { return; } - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); ops.RecvInitialMetadata(context); ops.RecvMessage(result); diff --git a/include/grpcpp/impl/codegen/completion_queue.h b/include/grpcpp/impl/codegen/completion_queue.h index 4ca19f46724..5eef2c281f3 100644 --- a/include/grpcpp/impl/codegen/completion_queue.h +++ b/include/grpcpp/impl/codegen/completion_queue.h @@ -300,14 +300,17 @@ class CompletionQueue : private GrpcLibraryCodegen { bool Pluck(internal::CompletionQueueTag* tag) { auto deadline = g_core_codegen_interface->gpr_inf_future(GPR_CLOCK_REALTIME); - auto ev = g_core_codegen_interface->grpc_completion_queue_pluck( - cq_, tag, deadline, nullptr); - bool ok = ev.success != 0; - void* ignored = tag; - GPR_CODEGEN_ASSERT(tag->FinalizeResult(&ignored, &ok)); - GPR_CODEGEN_ASSERT(ignored == tag); - // Ignore mutations by FinalizeResult: Pluck returns the C API status - return ev.success != 0; + while (true) { + auto ev = g_core_codegen_interface->grpc_completion_queue_pluck( + cq_, tag, deadline, nullptr); + bool ok = ev.success != 0; + void* ignored = tag; + if (tag->FinalizeResult(&ignored, &ok)) { + GPR_CODEGEN_ASSERT(ignored == tag); + // Ignore mutations by FinalizeResult: Pluck returns the C API status + return ev.success != 0; + } + } } /// Performs a single polling pluck on \a tag. diff --git a/include/grpcpp/impl/codegen/core_codegen.h b/include/grpcpp/impl/codegen/core_codegen.h index e9df96bf048..6ef184d01ab 100644 --- a/include/grpcpp/impl/codegen/core_codegen.h +++ b/include/grpcpp/impl/codegen/core_codegen.h @@ -63,6 +63,9 @@ class CoreCodegen final : public CoreCodegenInterface { void gpr_cv_signal(gpr_cv* cv) override; void gpr_cv_broadcast(gpr_cv* cv) override; + grpc_call_error grpc_call_start_batch(grpc_call* call, const grpc_op* ops, + size_t nops, void* tag, + void* reserved) override; grpc_call_error grpc_call_cancel_with_status(grpc_call* call, grpc_status_code status, const char* description, diff --git a/include/grpcpp/impl/codegen/core_codegen_interface.h b/include/grpcpp/impl/codegen/core_codegen_interface.h index 1167a188a24..25e3abccca5 100644 --- a/include/grpcpp/impl/codegen/core_codegen_interface.h +++ b/include/grpcpp/impl/codegen/core_codegen_interface.h @@ -100,6 +100,9 @@ class CoreCodegenInterface { virtual grpc_slice grpc_slice_new_with_len(void* p, size_t len, void (*destroy)(void*, size_t)) = 0; + virtual grpc_call_error grpc_call_start_batch(grpc_call* call, + const grpc_op* ops, size_t nops, + void* tag, void* reserved) = 0; virtual grpc_call_error grpc_call_cancel_with_status(grpc_call* call, grpc_status_code status, const char* description, diff --git a/include/grpcpp/impl/codegen/intercepted_channel.h b/include/grpcpp/impl/codegen/intercepted_channel.h new file mode 100644 index 00000000000..612e56d8620 --- /dev/null +++ b/include/grpcpp/impl/codegen/intercepted_channel.h @@ -0,0 +1,80 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTED_CHANNEL_H +#define GRPCPP_IMPL_CODEGEN_INTERCEPTED_CHANNEL_H + +#include + +namespace grpc { + +namespace internal { + +class InterceptorBatchMethodsImpl; + +/// An InterceptedChannel is available to client Interceptors. An +/// InterceptedChannel is unique to an interceptor, and when an RPC is started +/// on this channel, only those interceptors that come after this interceptor +/// see the RPC. +class InterceptedChannel : public ChannelInterface { + public: + virtual ~InterceptedChannel() { channel_ = nullptr; } + + /// Get the current channel state. If the channel is in IDLE and + /// \a try_to_connect is set to true, try to connect. + grpc_connectivity_state GetState(bool try_to_connect) override { + return channel_->GetState(try_to_connect); + } + + private: + InterceptedChannel(ChannelInterface* channel, int pos) + : channel_(channel), interceptor_pos_(pos) {} + + Call CreateCall(const RpcMethod& method, ClientContext* context, + CompletionQueue* cq) override { + return channel_->CreateCallInternal(method, context, cq, interceptor_pos_); + } + + void PerformOpsOnCall(CallOpSetInterface* ops, Call* call) override { + return channel_->PerformOpsOnCall(ops, call); + } + void* RegisterMethod(const char* method) override { + return channel_->RegisterMethod(method); + } + + void NotifyOnStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline, CompletionQueue* cq, + void* tag) override { + return channel_->NotifyOnStateChangeImpl(last_observed, deadline, cq, tag); + } + bool WaitForStateChangeImpl(grpc_connectivity_state last_observed, + gpr_timespec deadline) override { + return channel_->WaitForStateChangeImpl(last_observed, deadline); + } + + CompletionQueue* CallbackCQ() override { return channel_->CallbackCQ(); } + + ChannelInterface* channel_; + int interceptor_pos_; + + friend class InterceptorBatchMethodsImpl; +}; +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_INTERCEPTED_CHANNEL_H diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 6402a3a9466..cdd34b80d17 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -19,7 +19,17 @@ #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_H #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_H +#include +#include +#include +#include +#include +#include + namespace grpc { + +class Status; + namespace experimental { class InterceptedMessage { public: @@ -35,6 +45,7 @@ enum class InterceptionHookPoints { PRE_SEND_INITIAL_METADATA, PRE_SEND_MESSAGE, PRE_SEND_STATUS /* server only */, + PRE_SEND_CLOSE /* client only */, /* The following three are for hijacked clients only and can only be registered by the global interceptor */ PRE_RECV_INITIAL_METADATA, @@ -50,7 +61,7 @@ enum class InterceptionHookPoints { class InterceptorBatchMethods { public: - virtual ~InterceptorBatchMethods(); + virtual ~InterceptorBatchMethods(){}; // Queries to check whether the current batch has an interception hook point // of type \a type virtual bool QueryInterceptionHookPoint(InterceptionHookPoints type) = 0; @@ -60,7 +71,53 @@ class InterceptorBatchMethods { // Calling this indicates that the interceptor has hijacked the RPC (only // valid if the batch contains send_initial_metadata on the client side) virtual void Hijack() = 0; + + // Returns a modifable ByteBuffer holding serialized form of the message to be + // sent + virtual ByteBuffer* GetSendMessage() = 0; + + // Returns a modifiable multimap of the initial metadata to be sent + virtual std::multimap* + GetSendInitialMetadata() = 0; + + // Returns the status to be sent + virtual Status GetSendStatus() = 0; + + // Modifies the status with \a status + virtual void ModifySendStatus(const Status& status) = 0; + + // Returns a modifiable multimap of the trailing metadata to be sent + virtual std::multimap* + GetSendTrailingMetadata() = 0; + + // Returns a pointer to the modifiable received message. Note that the message + // is already deserialized + virtual void* GetRecvMessage() = 0; + + // Returns a modifiable multimap of the received initial metadata + virtual std::multimap* + GetRecvInitialMetadata() = 0; + + // Returns a modifiable view of the received status + virtual Status* GetRecvStatus() = 0; + + // Returns a modifiable multimap of the received trailing metadata + virtual std::multimap* + GetRecvTrailingMetadata() = 0; + + // Gets an intercepted channel. When a call is started on this interceptor, + // only interceptors after the current interceptor are created from the + // factory objects registered with the channel. + virtual std::unique_ptr GetInterceptedChannel() = 0; }; + +class Interceptor { + public: + virtual ~Interceptor() {} + + virtual void Intercept(InterceptorBatchMethods* methods) = 0; +}; + } // namespace experimental } // namespace grpc diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h new file mode 100644 index 00000000000..cf564977f69 --- /dev/null +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -0,0 +1,383 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H +#define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H + +#include +#include + +#include + +namespace grpc { +namespace internal { + +/// Internal methods for setting the state +class InternalInterceptorBatchMethods + : public experimental::InterceptorBatchMethods { + public: + virtual ~InternalInterceptorBatchMethods() {} + + virtual void AddInterceptionHookPoint( + experimental::InterceptionHookPoints type) = 0; + + virtual void SetSendMessage(ByteBuffer* buf) = 0; + + virtual void SetSendInitialMetadata( + std::multimap* metadata) = 0; + + virtual void SetSendStatus(grpc_status_code* code, + grpc::string* error_details, + grpc::string* error_message) = 0; + + virtual void SetSendTrailingMetadata( + std::multimap* metadata) = 0; + + virtual void SetRecvMessage(void* message) = 0; + + virtual void SetRecvInitialMetadata(MetadataMap* map) = 0; + + virtual void SetRecvStatus(Status* status) = 0; + + virtual void SetRecvTrailingMetadata(MetadataMap* map) = 0; +}; + +class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { + public: + InterceptorBatchMethodsImpl() { + for (auto i = static_cast(0); + i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; + i = static_cast( + static_cast(i) + 1)) { + hooks_[static_cast(i)] = false; + } + } + + ~InterceptorBatchMethodsImpl() {} + + bool QueryInterceptionHookPoint( + experimental::InterceptionHookPoints type) override { + return hooks_[static_cast(type)]; + } + + void Proceed() override { /* fill this */ + if (call_->client_rpc_info() != nullptr) { + return ProceedClient(); + } + GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr); + ProceedServer(); + } + + void Hijack() override { + // Only the client can hijack when sending down initial metadata + GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr && + call_->client_rpc_info() != nullptr); + // It is illegal to call Hijack twice + GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_); + auto* rpc_info = call_->client_rpc_info(); + rpc_info->hijacked_ = true; + rpc_info->hijacked_interceptor_ = current_interceptor_index_; + ClearHookPoints(); + ops_->SetHijackingState(); + ran_hijacking_interceptor_ = true; + rpc_info->RunInterceptor(this, current_interceptor_index_); + } + + void AddInterceptionHookPoint( + experimental::InterceptionHookPoints type) override { + hooks_[static_cast(type)] = true; + } + + ByteBuffer* GetSendMessage() override { return send_message_; } + + std::multimap* GetSendInitialMetadata() override { + return send_initial_metadata_; + } + + Status GetSendStatus() override { + return Status(static_cast(*code_), *error_message_, + *error_details_); + } + + void ModifySendStatus(const Status& status) override { + *code_ = static_cast(status.error_code()); + *error_details_ = status.error_details(); + *error_message_ = status.error_message(); + } + + std::multimap* GetSendTrailingMetadata() + override { + return send_trailing_metadata_; + } + + void* GetRecvMessage() override { return recv_message_; } + + std::multimap* GetRecvInitialMetadata() + override { + return recv_initial_metadata_->map(); + } + + Status* GetRecvStatus() override { return recv_status_; } + + std::multimap* GetRecvTrailingMetadata() + override { + return recv_trailing_metadata_->map(); + } + + void SetSendMessage(ByteBuffer* buf) override { send_message_ = buf; } + + void SetSendInitialMetadata( + std::multimap* metadata) override { + send_initial_metadata_ = metadata; + } + + void SetSendStatus(grpc_status_code* code, grpc::string* error_details, + grpc::string* error_message) override { + code_ = code; + error_details_ = error_details; + error_message_ = error_message; + } + + void SetSendTrailingMetadata( + std::multimap* metadata) override { + send_trailing_metadata_ = metadata; + } + + void SetRecvMessage(void* message) override { recv_message_ = message; } + + void SetRecvInitialMetadata(MetadataMap* map) override { + recv_initial_metadata_ = map; + } + + void SetRecvStatus(Status* status) override { recv_status_ = status; } + + void SetRecvTrailingMetadata(MetadataMap* map) override { + recv_trailing_metadata_ = map; + } + + std::unique_ptr GetInterceptedChannel() override { + auto* info = call_->client_rpc_info(); + if (info == nullptr) { + return std::unique_ptr(nullptr); + } + // The intercepted channel starts from the interceptor just after the + // current interceptor + return std::unique_ptr(new InterceptedChannel( + info->channel(), current_interceptor_index_ + 1)); + } + + // Clears all state + void ClearState() { + reverse_ = false; + ran_hijacking_interceptor_ = false; + ClearHookPoints(); + } + + // Prepares for Post_recv operations + void SetReverse() { + reverse_ = true; + ran_hijacking_interceptor_ = false; + ClearHookPoints(); + } + + // This needs to be set before interceptors are run + void SetCall(Call* call) { call_ = call; } + + // This needs to be set before interceptors are run using RunInterceptors(). + // Alternatively, RunInterceptors(std::function f) can be used. + void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; } + + // Returns true if no interceptors are run. This should be used only by + // subclasses of CallOpSetInterface. SetCall and SetCallOpSetInterface should + // have been called before this. After all the interceptors are done running, + // either ContinueFillOpsAfterInterception or + // ContinueFinalizeOpsAfterInterception will be called. Note that neither of + // them is invoked if there were no interceptors registered. + bool RunInterceptors() { + GPR_CODEGEN_ASSERT(ops_); + auto* client_rpc_info = call_->client_rpc_info(); + if (client_rpc_info != nullptr) { + if (client_rpc_info->interceptors_.size() == 0) { + return true; + } else { + RunClientInterceptors(); + return false; + } + } + + auto* server_rpc_info = call_->server_rpc_info(); + if (server_rpc_info == nullptr || + server_rpc_info->interceptors_.size() == 0) { + return true; + } + RunServerInterceptors(); + return false; + } + + // Returns true if no interceptors are run. Returns false otherwise if there + // are interceptors registered. After the interceptors are done running \a f + // will be invoked. This is to be used only by BaseAsyncRequest and + // SyncRequest. + bool RunInterceptors(std::function f) { + // This is used only by the server for initial call request + GPR_CODEGEN_ASSERT(reverse_ == true); + GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr); + auto* server_rpc_info = call_->server_rpc_info(); + if (server_rpc_info == nullptr || + server_rpc_info->interceptors_.size() == 0) { + return true; + } + callback_ = std::move(f); + RunServerInterceptors(); + return false; + } + + private: + void RunClientInterceptors() { + auto* rpc_info = call_->client_rpc_info(); + if (!reverse_) { + current_interceptor_index_ = 0; + } else { + if (rpc_info->hijacked_) { + current_interceptor_index_ = rpc_info->hijacked_interceptor_; + } else { + current_interceptor_index_ = rpc_info->interceptors_.size() - 1; + } + } + rpc_info->RunInterceptor(this, current_interceptor_index_); + } + + void RunServerInterceptors() { + auto* rpc_info = call_->server_rpc_info(); + if (!reverse_) { + current_interceptor_index_ = 0; + } else { + current_interceptor_index_ = rpc_info->interceptors_.size() - 1; + } + rpc_info->RunInterceptor(this, current_interceptor_index_); + } + + void ProceedClient() { + auto* rpc_info = call_->client_rpc_info(); + if (rpc_info->hijacked_ && !reverse_ && + current_interceptor_index_ == rpc_info->hijacked_interceptor_ && + !ran_hijacking_interceptor_) { + // We now need to provide hijacked recv ops to this interceptor + ClearHookPoints(); + ops_->SetHijackingState(); + ran_hijacking_interceptor_ = true; + rpc_info->RunInterceptor(this, current_interceptor_index_); + return; + } + if (!reverse_) { + current_interceptor_index_++; + // We are going down the stack of interceptors + if (current_interceptor_index_ < rpc_info->interceptors_.size()) { + if (rpc_info->hijacked_ && + current_interceptor_index_ > rpc_info->hijacked_interceptor_) { + // This is a hijacked RPC and we are done with hijacking + ops_->ContinueFillOpsAfterInterception(); + } else { + rpc_info->RunInterceptor(this, current_interceptor_index_); + } + } else { + // we are done running all the interceptors without any hijacking + ops_->ContinueFillOpsAfterInterception(); + } + } else { + // We are going up the stack of interceptors + if (current_interceptor_index_ > 0) { + // Continue running interceptors + current_interceptor_index_--; + rpc_info->RunInterceptor(this, current_interceptor_index_); + } else { + // we are done running all the interceptors without any hijacking + ops_->ContinueFinalizeResultAfterInterception(); + } + } + } + + void ProceedServer() { + auto* rpc_info = call_->server_rpc_info(); + if (!reverse_) { + current_interceptor_index_++; + if (current_interceptor_index_ < rpc_info->interceptors_.size()) { + return rpc_info->RunInterceptor(this, current_interceptor_index_); + } else if (ops_) { + return ops_->ContinueFillOpsAfterInterception(); + } + } else { + // We are going up the stack of interceptors + if (current_interceptor_index_ > 0) { + // Continue running interceptors + current_interceptor_index_--; + return rpc_info->RunInterceptor(this, current_interceptor_index_); + } else if (ops_) { + return ops_->ContinueFinalizeResultAfterInterception(); + } + } + GPR_CODEGEN_ASSERT(callback_); + callback_(); + } + + void ClearHookPoints() { + for (auto i = static_cast(0); + i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; + i = static_cast( + static_cast(i) + 1)) { + hooks_[static_cast(i)] = false; + } + } + + std::array( + experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> + hooks_; + + size_t current_interceptor_index_ = 0; // Current iterator + bool reverse_ = false; + bool ran_hijacking_interceptor_ = false; + Call* call_ = nullptr; // The Call object is present along with CallOpSet + // object/callback + CallOpSetInterface* ops_ = nullptr; + std::function callback_; + + ByteBuffer* send_message_ = nullptr; + + std::multimap* send_initial_metadata_; + + grpc_status_code* code_ = nullptr; + grpc::string* error_details_ = nullptr; + grpc::string* error_message_ = nullptr; + Status send_status_; + + std::multimap* send_trailing_metadata_ = nullptr; + + void* recv_message_ = nullptr; + + MetadataMap* recv_initial_metadata_ = nullptr; + + Status* recv_status_ = nullptr; + + MetadataMap* recv_trailing_metadata_ = nullptr; +}; + +} // namespace internal +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H diff --git a/include/grpcpp/impl/codegen/metadata_map.h b/include/grpcpp/impl/codegen/metadata_map.h index 5e062a50f83..0bba3ed4e36 100644 --- a/include/grpcpp/impl/codegen/metadata_map.h +++ b/include/grpcpp/impl/codegen/metadata_map.h @@ -19,6 +19,8 @@ #ifndef GRPCPP_IMPL_CODEGEN_METADATA_MAP_H #define GRPCPP_IMPL_CODEGEN_METADATA_MAP_H +#include + #include #include diff --git a/include/grpcpp/impl/codegen/method_handler_impl.h b/include/grpcpp/impl/codegen/method_handler_impl.h index 53117f941b4..4f02e3e39b9 100644 --- a/include/grpcpp/impl/codegen/method_handler_impl.h +++ b/include/grpcpp/impl/codegen/method_handler_impl.h @@ -59,21 +59,21 @@ class RpcMethodHandler : public MethodHandler { : func_(func), service_(service) {} void RunHandler(const HandlerParameter& param) final { - RequestType req; - Status status = SerializationTraits::Deserialize( - param.request.bbuf_ptr(), &req); ResponseType rsp; + Status status = param.status; if (status.ok()) { - status = CatchingFunctionHandler([this, ¶m, &req, &rsp] { - return func_(service_, param.server_context, &req, &rsp); + status = CatchingFunctionHandler([this, ¶m, &rsp] { + return func_(service_, param.server_context, + static_cast(param.request), &rsp); }); + delete static_cast(param.request); } GPR_CODEGEN_ASSERT(!param.server_context->sent_initial_metadata_); CallOpSet ops; - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); @@ -81,11 +81,24 @@ class RpcMethodHandler : public MethodHandler { if (status.ok()) { status = ops.SendMessage(rsp); } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); } + void* Deserialize(grpc_byte_buffer* req, Status* status) final { + ByteBuffer buf; + buf.set_buffer(req); + auto* request = new RequestType(); + *status = SerializationTraits::Deserialize(&buf, request); + buf.Release(); + if (status->ok()) { + return request; + } + delete request; + return nullptr; + } + private: /// Application provided rpc handler function. std::function ops; if (!param.server_context->sent_initial_metadata_) { - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); @@ -126,7 +139,7 @@ class ClientStreamingHandler : public MethodHandler { if (status.ok()) { status = ops.SendMessage(rsp); } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); } @@ -150,26 +163,25 @@ class ServerStreamingHandler : public MethodHandler { : func_(func), service_(service) {} void RunHandler(const HandlerParameter& param) final { - RequestType req; - Status status = SerializationTraits::Deserialize( - param.request.bbuf_ptr(), &req); - + Status status = param.status; if (status.ok()) { ServerWriter writer(param.call, param.server_context); - status = CatchingFunctionHandler([this, ¶m, &req, &writer] { - return func_(service_, param.server_context, &req, &writer); + status = CatchingFunctionHandler([this, ¶m, &writer] { + return func_(service_, param.server_context, + static_cast(param.request), &writer); }); + delete static_cast(param.request); } CallOpSet ops; if (!param.server_context->sent_initial_metadata_) { - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); } } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); if (param.server_context->has_pending_ops_) { param.call->cq()->Pluck(¶m.server_context->pending_ops_); @@ -177,6 +189,19 @@ class ServerStreamingHandler : public MethodHandler { param.call->cq()->Pluck(&ops); } + void* Deserialize(grpc_byte_buffer* req, Status* status) final { + ByteBuffer buf; + buf.set_buffer(req); + auto* request = new RequestType(); + *status = SerializationTraits::Deserialize(&buf, request); + buf.Release(); + if (status->ok()) { + return request; + } + delete request; + return nullptr; + } + private: std::function*)> @@ -206,7 +231,7 @@ class TemplatedBidiStreamingHandler : public MethodHandler { CallOpSet ops; if (!param.server_context->sent_initial_metadata_) { - ops.SendInitialMetadata(param.server_context->initial_metadata_, + ops.SendInitialMetadata(¶m.server_context->initial_metadata_, param.server_context->initial_metadata_flags()); if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); @@ -218,7 +243,7 @@ class TemplatedBidiStreamingHandler : public MethodHandler { "Service did not provide response message"); } } - ops.ServerSendStatus(param.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); if (param.server_context->has_pending_ops_) { param.call->cq()->Pluck(¶m.server_context->pending_ops_); @@ -281,14 +306,14 @@ class ErrorMethodHandler : public MethodHandler { static void FillOps(ServerContext* context, T* ops) { Status status(code, ""); if (!context->sent_initial_metadata_) { - ops->SendInitialMetadata(context->initial_metadata_, + ops->SendInitialMetadata(&context->initial_metadata_, context->initial_metadata_flags()); if (context->compression_level_set()) { ops->set_compression_level(context->compression_level()); } context->sent_initial_metadata_ = true; } - ops->ServerSendStatus(context->trailing_metadata_, status); + ops->ServerSendStatus(&context->trailing_metadata_, status); } void RunHandler(const HandlerParameter& param) final { @@ -296,11 +321,14 @@ class ErrorMethodHandler : public MethodHandler { FillOps(param.server_context, &ops); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); - // We also have to destroy any request payload in the handler parameter - ByteBuffer* payload = param.request.bbuf_ptr(); - if (payload != nullptr) { - payload->Clear(); + } + + void* Deserialize(grpc_byte_buffer* req, Status* status) final { + // We have to destroy any request payload + if (req != nullptr) { + g_core_codegen_interface->grpc_byte_buffer_destroy(req); } + return nullptr; } }; diff --git a/include/grpcpp/impl/codegen/rpc_service_method.h b/include/grpcpp/impl/codegen/rpc_service_method.h index 5cf88e216f9..44da2bd7689 100644 --- a/include/grpcpp/impl/codegen/rpc_service_method.h +++ b/include/grpcpp/impl/codegen/rpc_service_method.h @@ -40,17 +40,26 @@ class MethodHandler { public: virtual ~MethodHandler() {} struct HandlerParameter { - HandlerParameter(Call* c, ServerContext* context, grpc_byte_buffer* req) - : call(c), server_context(context) { - request.set_buffer(req); - } - ~HandlerParameter() { request.Release(); } + HandlerParameter(Call* c, ServerContext* context, void* req, + Status req_status) + : call(c), server_context(context), request(req), status(req_status) {} + ~HandlerParameter() {} Call* call; ServerContext* server_context; - // Handler required to destroy these contents - ByteBuffer request; + void* request; + Status status; }; virtual void RunHandler(const HandlerParameter& param) = 0; + + /* Returns a pointer to the deserialized request. \a status reflects the + result of deserialization. This pointer and the status should be filled in + a HandlerParameter and passed to RunHandler. It is illegal to access the + pointer after calling RunHandler. Ownership of the deserialized request is + retained by the handler. Returns nullptr if deserialization failed. */ + virtual void* Deserialize(grpc_byte_buffer* req, Status* status) { + GPR_CODEGEN_ASSERT(req == nullptr); + return nullptr; + } }; /// Server side rpc method class diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h index d53c09aa1bf..7559fb3b346 100644 --- a/include/grpcpp/impl/codegen/server_context.h +++ b/include/grpcpp/impl/codegen/server_context.h @@ -26,11 +26,13 @@ #include #include +#include #include #include #include #include #include +#include #include #include @@ -285,6 +287,18 @@ class ServerContext { uint32_t initial_metadata_flags() const { return 0; } + experimental::ServerRpcInfo* set_server_rpc_info( + const char* method, + const std::vector< + std::unique_ptr>& + creators) { + if (creators.size() != 0) { + rpc_info_ = new experimental::ServerRpcInfo(this, method); + rpc_info_->RegisterInterceptors(creators); + } + return rpc_info_; + } + CompletionOp* completion_op_; bool has_notify_when_done_tag_; void* async_notify_when_done_tag_; @@ -306,6 +320,8 @@ class ServerContext { internal::CallOpSendMessage> pending_ops_; bool has_pending_ops_; + + experimental::ServerRpcInfo* rpc_info_ = nullptr; }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/server_interceptor.h b/include/grpcpp/impl/codegen/server_interceptor.h new file mode 100644 index 00000000000..c39e9a988d5 --- /dev/null +++ b/include/grpcpp/impl/codegen/server_interceptor.h @@ -0,0 +1,100 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_SERVER_INTERCEPTOR_H +#define GRPCPP_IMPL_CODEGEN_SERVER_INTERCEPTOR_H + +#include +#include + +#include +#include +#include + +namespace grpc { + +class ServerContext; + +namespace internal { +class InterceptorBatchMethodsImpl; +} + +namespace experimental { +class ServerRpcInfo; + +class ServerInterceptorFactoryInterface { + public: + virtual ~ServerInterceptorFactoryInterface() {} + virtual Interceptor* CreateServerInterceptor(ServerRpcInfo* info) = 0; +}; + +class ServerRpcInfo { + public: + ~ServerRpcInfo(){}; + + ServerRpcInfo(const ServerRpcInfo&) = delete; + ServerRpcInfo(ServerRpcInfo&&) = default; + ServerRpcInfo& operator=(ServerRpcInfo&&) = default; + + // Getter methods + const char* method() { return method_; } + grpc::ServerContext* server_context() { return ctx_; } + + private: + ServerRpcInfo(grpc::ServerContext* ctx, const char* method) + : ctx_(ctx), method_(method) { + ref_.store(1); + } + + // Runs interceptor at pos \a pos. + void RunInterceptor( + experimental::InterceptorBatchMethods* interceptor_methods, size_t pos) { + GPR_CODEGEN_ASSERT(pos < interceptors_.size()); + interceptors_[pos]->Intercept(interceptor_methods); + } + + void RegisterInterceptors( + const std::vector< + std::unique_ptr>& + creators) { + for (const auto& creator : creators) { + interceptors_.push_back(std::unique_ptr( + creator->CreateServerInterceptor(this))); + } + } + + void Ref() { ref_++; } + void Unref() { + if (--ref_ == 0) { + delete this; + } + } + + grpc::ServerContext* ctx_ = nullptr; + const char* method_ = nullptr; + std::atomic_int ref_; + std::vector> interceptors_; + + friend class internal::InterceptorBatchMethodsImpl; + friend class grpc::ServerContext; +}; + +} // namespace experimental +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_SERVER_INTERCEPTOR_H diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index 237991cde60..92c87a5f7ec 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -21,10 +21,12 @@ #include #include +#include #include #include #include #include +#include namespace grpc { @@ -148,44 +150,67 @@ class ServerInterface : public internal::CallHook { public: BaseAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag, + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize); virtual ~BaseAsyncRequest(); bool FinalizeResult(void** tag, bool* status) override; + private: + void ContinueFinalizeResultAfterInterception(); + protected: ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; const bool delete_on_finalize_; grpc_call* call_; + internal::Call call_wrapper_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; + bool done_intercepting_; }; class RegisteredAsyncRequest : public BaseAsyncRequest { public: RegisteredAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag); - - // uses BaseAsyncRequest::FinalizeResult + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, + const char* name); + + virtual bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info(name_, + *server_->interceptor_creators())); + return BaseAsyncRequest::FinalizeResult(tag, status); + } protected: void IssueRequest(void* registered_method, grpc_byte_buffer** payload, ServerCompletionQueue* notification_cq); + const char* name_; }; class NoPayloadAsyncRequest final : public RegisteredAsyncRequest { public: - NoPayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + NoPayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag) { - IssueRequest(registered_method, nullptr, notification_cq); + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()) { + IssueRequest(registered_method->server_tag(), nullptr, notification_cq); } // uses RegisteredAsyncRequest::FinalizeResult @@ -194,13 +219,15 @@ class ServerInterface : public internal::CallHook { template class PayloadAsyncRequest final : public RegisteredAsyncRequest { public: - PayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + PayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, Message* request) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag), + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()), registered_method_(registered_method), server_(server), context_(context), @@ -209,7 +236,8 @@ class ServerInterface : public internal::CallHook { notification_cq_(notification_cq), tag_(tag), request_(request) { - IssueRequest(registered_method, payload_.bbuf_ptr(), notification_cq); + IssueRequest(registered_method->server_tag(), payload_.bbuf_ptr(), + notification_cq); } ~PayloadAsyncRequest() { @@ -217,6 +245,10 @@ class ServerInterface : public internal::CallHook { } bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return RegisteredAsyncRequest::FinalizeResult(tag, status); + } if (*status) { if (!payload_.Valid() || !SerializationTraits::Deserialize( payload_.bbuf_ptr(), request_) @@ -235,15 +267,20 @@ class ServerInterface : public internal::CallHook { return false; } } + /* Set interception point for recv message */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request_); return RegisteredAsyncRequest::FinalizeResult(tag, status); } private: - void* const registered_method_; + internal::RpcServiceMethod* const registered_method_; ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; Message* const request_; @@ -272,9 +309,8 @@ class ServerInterface : public internal::CallHook { ServerCompletionQueue* notification_cq, void* tag, Message* message) { GPR_CODEGEN_ASSERT(method); - new PayloadAsyncRequest(method->server_tag(), this, context, - stream, call_cq, notification_cq, tag, - message); + new PayloadAsyncRequest(method, this, context, stream, call_cq, + notification_cq, tag, message); } void RequestAsyncCall(internal::RpcServiceMethod* method, @@ -283,8 +319,8 @@ class ServerInterface : public internal::CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) { GPR_CODEGEN_ASSERT(method); - new NoPayloadAsyncRequest(method->server_tag(), this, context, stream, - call_cq, notification_cq, tag); + new NoPayloadAsyncRequest(method, this, context, stream, call_cq, + notification_cq, tag); } void RequestAsyncGenericCall(GenericServerContext* context, @@ -295,6 +331,13 @@ class ServerInterface : public internal::CallHook { new GenericAsyncRequest(this, context, stream, call_cq, notification_cq, tag, true); } + + private: + virtual const std::vector< + std::unique_ptr>* + interceptor_creators() { + return nullptr; + } }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/sync_stream.h b/include/grpcpp/impl/codegen/sync_stream.h index cbfcf25d0a4..6981076f04e 100644 --- a/include/grpcpp/impl/codegen/sync_stream.h +++ b/include/grpcpp/impl/codegen/sync_stream.h @@ -250,7 +250,7 @@ class ClientReader final : public ClientReaderInterface { ::grpc::internal::CallOpSendMessage, ::grpc::internal::CallOpClientSendClose> ops; - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); // TODO(ctiller): don't assert GPR_CODEGEN_ASSERT(ops.SendMessage(request).ok()); @@ -327,7 +327,7 @@ class ClientWriter : public ClientWriterInterface { ops.ClientSendClose(); } if (context_->initial_metadata_corked_) { - ops.SendInitialMetadata(context_->send_initial_metadata_, + ops.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); context_->set_initial_metadata_corked(false); } @@ -386,7 +386,7 @@ class ClientWriter : public ClientWriterInterface { if (!context_->initial_metadata_corked_) { ::grpc::internal::CallOpSet<::grpc::internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); call_.PerformOps(&ops); cq_.Pluck(&ops); @@ -498,7 +498,7 @@ class ClientReaderWriter final : public ClientReaderWriterInterface { ops.ClientSendClose(); } if (context_->initial_metadata_corked_) { - ops.SendInitialMetadata(context_->send_initial_metadata_, + ops.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); context_->set_initial_metadata_corked(false); } @@ -557,7 +557,7 @@ class ClientReaderWriter final : public ClientReaderWriterInterface { if (!context_->initial_metadata_corked_) { ::grpc::internal::CallOpSet<::grpc::internal::CallOpSendInitialMetadata> ops; - ops.SendInitialMetadata(context->send_initial_metadata_, + ops.SendInitialMetadata(&context->send_initial_metadata_, context->initial_metadata_flags()); call_.PerformOps(&ops); cq_.Pluck(&ops); @@ -583,7 +583,7 @@ class ServerReader final : public ServerReaderInterface { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); internal::CallOpSet ops; - ops.SendInitialMetadata(ctx_->initial_metadata_, + ops.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops.set_compression_level(ctx_->compression_level()); @@ -635,7 +635,7 @@ class ServerWriter final : public ServerWriterInterface { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); internal::CallOpSet ops; - ops.SendInitialMetadata(ctx_->initial_metadata_, + ops.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops.set_compression_level(ctx_->compression_level()); @@ -660,7 +660,7 @@ class ServerWriter final : public ServerWriterInterface { return false; } if (!ctx_->sent_initial_metadata_) { - ctx_->pending_ops_.SendInitialMetadata(ctx_->initial_metadata_, + ctx_->pending_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ctx_->pending_ops_.set_compression_level(ctx_->compression_level()); @@ -708,7 +708,7 @@ class ServerReaderWriterBody final { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); CallOpSet ops; - ops.SendInitialMetadata(ctx_->initial_metadata_, + ops.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ops.set_compression_level(ctx_->compression_level()); @@ -738,7 +738,7 @@ class ServerReaderWriterBody final { return false; } if (!ctx_->sent_initial_metadata_) { - ctx_->pending_ops_.SendInitialMetadata(ctx_->initial_metadata_, + ctx_->pending_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { ctx_->pending_ops_.set_compression_level(ctx_->compression_level()); diff --git a/include/grpcpp/server.h b/include/grpcpp/server.h index 8d3e856502c..2b89ffd317a 100644 --- a/include/grpcpp/server.h +++ b/include/grpcpp/server.h @@ -174,7 +174,11 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { std::shared_ptr>> sync_server_cqs, int min_pollers, int max_pollers, int sync_cq_timeout_msec, - grpc_resource_quota* server_rq = nullptr); + grpc_resource_quota* server_rq = nullptr, + std::vector< + std::unique_ptr> + interceptor_creators = std::vector>()); /// Start the server. /// @@ -187,6 +191,12 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { grpc_server* server() override { return server_; }; private: + const std::vector< + std::unique_ptr>* + interceptor_creators() override { + return &interceptor_creators_; + } + friend class AsyncGenericService; friend class ServerBuilder; friend class ServerInitializer; @@ -251,6 +261,9 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { // A special handler for resource exhausted in sync case std::unique_ptr resource_exhausted_handler_; + + std::vector> + interceptor_creators_; }; } // namespace grpc diff --git a/include/grpcpp/server_builder.h b/include/grpcpp/server_builder.h index a58a59c2d8b..028b8cffaa7 100644 --- a/include/grpcpp/server_builder.h +++ b/include/grpcpp/server_builder.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -212,6 +213,29 @@ class ServerBuilder { /// doc/workarounds.md. ServerBuilder& EnableWorkaround(grpc_workaround_list id); + /// NOTE: class experimental_type is not part of the public API of this class. + /// TODO(yashykt): Integrate into public API when this is no longer + /// experimental. + class experimental_type { + public: + explicit experimental_type(ServerBuilder* builder) : builder_(builder) {} + + void SetInterceptorCreators( + std::vector< + std::unique_ptr> + interceptor_creators) { + builder_->interceptor_creators_ = std::move(interceptor_creators); + } + + private: + ServerBuilder* builder_; + }; + + /// NOTE: The function experimental() is not stable public API. It is a view + /// to the experimental components of this class. It may be changed or removed + /// at any time. + experimental_type experimental() { return experimental_type(this); } + protected: /// Experimental, to be deprecated struct Port { @@ -297,6 +321,8 @@ class ServerBuilder { grpc_compression_algorithm algorithm; } maybe_default_compression_algorithm_; uint32_t enabled_compression_algorithms_bitset_; + std::vector> + interceptor_creators_; }; } // namespace grpc diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc index 2cab41b3f56..15e3ccb3c9f 100644 --- a/src/cpp/client/channel_cc.cc +++ b/src/cpp/client/channel_cc.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -57,9 +58,8 @@ Channel::Channel( std::unique_ptr>> interceptor_creators) : host_(host), c_channel_(channel) { - auto* vector = interceptor_creators.release(); - if (vector != nullptr) { - interceptor_creators_ = std::move(*vector); + if (interceptor_creators != nullptr) { + interceptor_creators_ = std::move(*interceptor_creators); } g_gli_initializer.summon(); } @@ -112,9 +112,10 @@ void ChannelResetConnectionBackoff(Channel* channel) { } // namespace experimental -internal::Call Channel::CreateCall(const internal::RpcMethod& method, - ClientContext* context, - CompletionQueue* cq) { +internal::Call Channel::CreateCallInternal(const internal::RpcMethod& method, + ClientContext* context, + CompletionQueue* cq, + int interceptor_pos) { const bool kRegistered = method.channel_tag() && context->authority().empty(); grpc_call* c_call = nullptr; if (kRegistered) { @@ -147,17 +148,22 @@ internal::Call Channel::CreateCall(const internal::RpcMethod& method, } grpc_census_call_set_context(c_call, context->census_context()); context->set_call(c_call, shared_from_this()); - return internal::Call(c_call, this, cq); + + auto* info = context->set_client_rpc_info( + method.name(), this, interceptor_creators_, interceptor_pos); + return internal::Call(c_call, this, cq, info); +} + +internal::Call Channel::CreateCall(const internal::RpcMethod& method, + ClientContext* context, + CompletionQueue* cq) { + return CreateCallInternal(method, context, cq, 0); } void Channel::PerformOpsOnCall(internal::CallOpSetInterface* ops, internal::Call* call) { - static const size_t MAX_OPS = 8; - size_t nops = 0; - grpc_op cops[MAX_OPS]; - ops->FillOps(call->call(), cops, &nops); - GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call->call(), cops, nops, - ops->cq_tag(), nullptr)); + ops->FillOps( + call); // Make a copy of call. It's fine since Call just has pointers } void* Channel::RegisterMethod(const char* method) { diff --git a/src/cpp/common/core_codegen.cc b/src/cpp/common/core_codegen.cc index 619aacadaa0..cfaa2e7b193 100644 --- a/src/cpp/common/core_codegen.cc +++ b/src/cpp/common/core_codegen.cc @@ -102,6 +102,13 @@ size_t CoreCodegen::grpc_byte_buffer_length(grpc_byte_buffer* bb) { return ::grpc_byte_buffer_length(bb); } +grpc_call_error CoreCodegen::grpc_call_start_batch(grpc_call* call, + const grpc_op* ops, + size_t nops, void* tag, + void* reserved) { + return ::grpc_call_start_batch(call, ops, nops, tag, reserved); +} + grpc_call_error CoreCodegen::grpc_call_cancel_with_status( grpc_call* call, grpc_status_code status, const char* description, void* reserved) { diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc index 8417c45e641..fc42b6c8868 100644 --- a/src/cpp/server/server_builder.cc +++ b/src/cpp/server/server_builder.cc @@ -263,7 +263,8 @@ std::unique_ptr ServerBuilder::BuildAndStart() { std::unique_ptr server(new Server( max_receive_message_size_, &args, sync_server_cqs, sync_server_settings_.min_pollers, sync_server_settings_.max_pollers, - sync_server_settings_.cq_timeout_msec, resource_quota_)); + sync_server_settings_.cq_timeout_msec, resource_quota_, + std::move(interceptor_creators_))); if (has_sync_methods) { // This is a Sync server diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 7aeddff6435..59a531e2725 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -27,7 +27,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -38,8 +40,10 @@ #include #include "src/core/ext/transport/inproc/inproc_transport.h" +#include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/profiling/timers.h" #include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/completion_queue.h" #include "src/cpp/client/create_channel_internal.h" #include "src/cpp/server/health/default_health_check_service.h" #include "src/cpp/thread_manager/thread_manager.h" @@ -127,10 +131,13 @@ class Server::UnimplementedAsyncResponse final ~UnimplementedAsyncResponse() { delete request_; } bool FinalizeResult(void** tag, bool* status) override { - internal::CallOpSet< - internal::CallOpSendInitialMetadata, - internal::CallOpServerSendStatus>::FinalizeResult(tag, status); - delete this; + if (internal::CallOpSet< + internal::CallOpSendInitialMetadata, + internal::CallOpServerSendStatus>::FinalizeResult(tag, status)) { + delete this; + } else { + // The tag was swallowed due to interception. We will see it again. + } return false; } @@ -208,13 +215,18 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { public: explicit CallData(Server* server, SyncRequest* mrd) : cq_(mrd->cq_), - call_(mrd->call_, server, &cq_, server->max_receive_message_size()), ctx_(mrd->deadline_, &mrd->request_metadata_), has_request_payload_(mrd->has_request_payload_), request_payload_(has_request_payload_ ? mrd->request_payload_ : nullptr), + request_(nullptr), method_(mrd->method_), - server_(server) { + call_(mrd->call_, server, &cq_, server->max_receive_message_size(), + ctx_.set_server_rpc_info(method_->name(), + server->interceptor_creators_)), + server_(server), + global_callbacks_(nullptr), + resources_(false) { ctx_.set_call(mrd->call_); ctx_.cq_ = &cq_; GPR_ASSERT(mrd->in_flight_); @@ -230,33 +242,73 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { void Run(const std::shared_ptr& global_callbacks, bool resources) { - ctx_.BeginCompletionOp(&call_); - global_callbacks->PreSynchronousRequest(&ctx_); - auto* handler = resources ? method_->handler() - : server_->resource_exhausted_handler_.get(); - handler->RunHandler(internal::MethodHandler::HandlerParameter( - &call_, &ctx_, request_payload_)); - global_callbacks->PostSynchronousRequest(&ctx_); - request_payload_ = nullptr; - - cq_.Shutdown(); + global_callbacks_ = global_callbacks; + resources_ = resources; + + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&ctx_.client_metadata_); + + if (has_request_payload_) { + // Set interception point for RECV MESSAGE + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + request_ = handler->Deserialize(request_payload_, &request_status_); + + request_payload_ = nullptr; + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request_); + } - internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); - cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + auto f = std::bind(&CallData::ContinueRunAfterInterception, this); + if (interceptor_methods_.RunInterceptors(f)) { + ContinueRunAfterInterception(); + } else { + // There were interceptors to be run, so ContinueRunAfterInterception + // will be run when interceptors are done. + } + } - /* Ensure the cq_ is shutdown */ - DummyTag ignored_tag; - GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + void ContinueRunAfterInterception() { + { + ctx_.BeginCompletionOp(&call_); + global_callbacks_->PreSynchronousRequest(&ctx_); + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + handler->RunHandler(internal::MethodHandler::HandlerParameter( + &call_, &ctx_, request_, request_status_)); + request_ = nullptr; + global_callbacks_->PostSynchronousRequest(&ctx_); + + cq_.Shutdown(); + + internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); + cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + + /* Ensure the cq_ is shutdown */ + DummyTag ignored_tag; + GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + } + delete this; } private: CompletionQueue cq_; - internal::Call call_; ServerContext ctx_; const bool has_request_payload_; grpc_byte_buffer* request_payload_; + void* request_; + Status request_status_; internal::RpcServiceMethod* const method_; + internal::Call call_; Server* server_; + std::shared_ptr global_callbacks_; + bool resources_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; }; private: @@ -318,8 +370,9 @@ class Server::SyncRequestThreadManager : public ThreadManager { } if (ok) { - // Calldata takes ownership of the completion queue inside sync_req - SyncRequest::CallData cd(server_, sync_req); + // Calldata takes ownership of the completion queue and interceptors + // inside sync_req + auto* cd = new SyncRequest::CallData(server_, sync_req); // Prepare for the next request if (!IsShutdown()) { sync_req->SetupRequest(); // Create new completion queue for sync_req @@ -327,7 +380,7 @@ class Server::SyncRequestThreadManager : public ThreadManager { } GPR_TIMER_SCOPE("cd.Run()", 0); - cd.Run(global_callbacks_, resources); + cd->Run(global_callbacks_, resources); } // TODO (sreek) If ok is false here (which it isn't in case of // grpc_request_registered_call), we should still re-queue the request @@ -389,7 +442,10 @@ Server::Server( std::shared_ptr>> sync_server_cqs, int min_pollers, int max_pollers, int sync_cq_timeout_msec, - grpc_resource_quota* server_rq) + grpc_resource_quota* server_rq, + std::vector< + std::unique_ptr> + interceptor_creators) : max_receive_message_size_(max_receive_message_size), sync_server_cqs_(std::move(sync_server_cqs)), started_(false), @@ -398,7 +454,8 @@ Server::Server( has_generic_service_(false), server_(nullptr), server_initializer_(new ServerInitializer(this)), - health_check_service_disabled_(false) { + health_check_service_disabled_(false), + interceptor_creators_(std::move(interceptor_creators)) { g_gli_initializer.summon(); gpr_once_init(&g_once_init_callbacks, InitGlobalCallbacks); global_callbacks_ = g_callbacks; @@ -681,31 +738,27 @@ void Server::Wait() { void Server::PerformOpsOnCall(internal::CallOpSetInterface* ops, internal::Call* call) { - static const size_t MAX_OPS = 8; - size_t nops = 0; - grpc_op cops[MAX_OPS]; - ops->FillOps(call->call(), cops, &nops); - auto result = - grpc_call_start_batch(call->call(), cops, nops, ops->cq_tag(), nullptr); - if (result != GRPC_CALL_OK) { - gpr_log(GPR_ERROR, "Fatal: grpc_call_start_batch returned %d", result); - grpc_call_log_batch(__FILE__, __LINE__, GPR_LOG_SEVERITY_ERROR, - call->call(), cops, nops, ops); - abort(); - } + ops->FillOps(call); } ServerInterface::BaseAsyncRequest::BaseAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag, bool delete_on_finalize) + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) : server_(server), context_(context), stream_(stream), call_cq_(call_cq), + notification_cq_(notification_cq), tag_(tag), delete_on_finalize_(delete_on_finalize), - call_(nullptr) { + call_(nullptr), + done_intercepting_(false) { + /* Set up interception state partially for the receive ops. call_wrapper_ is + * not filled at this point, but it will be filled before the interceptors are + * run. */ + interceptor_methods_.SetCall(&call_wrapper_); + interceptor_methods_.SetReverse(); call_cq_->RegisterAvalanching(); // This op will trigger more ops } @@ -715,15 +768,45 @@ ServerInterface::BaseAsyncRequest::~BaseAsyncRequest() { bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, bool* status) { + if (done_intercepting_) { + *tag = tag_; + if (delete_on_finalize_) { + delete this; + } + return true; + } context_->set_call(call_); context_->cq_ = call_cq_; - internal::Call call(call_, server_, call_cq_, - server_->max_receive_message_size()); - if (*status && call_) { - context_->BeginCompletionOp(&call); + if (call_wrapper_.call() == nullptr) { + // Fill it since it is empty. + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), nullptr); } + // just the pointers inside call are copied here - stream_->BindCall(&call); + stream_->BindCall(&call_wrapper_); + + if (*status && call_ && call_wrapper_.server_rpc_info()) { + done_intercepting_ = true; + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_); + auto f = std::bind(&ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception, + this); + if (interceptor_methods_.RunInterceptors(f)) { + // There are no interceptors to run. Continue + } else { + // There were interceptors to be run, so + // ContinueFinalizeResultAfterInterception will be run when interceptors + // are done. + return false; + } + } + if (*status && call_) { + context_->BeginCompletionOp(&call_wrapper_); + } *tag = tag_; if (delete_on_finalize_) { delete this; @@ -731,11 +814,25 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, return true; } +void ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception() { + context_->BeginCompletionOp(&call_wrapper_); + // Queue a tag which will be returned immediately + grpc_core::ExecCtx exec_ctx; + grpc_cq_begin_op(notification_cq_->cq(), this); + grpc_cq_end_op( + notification_cq_->cq(), this, GRPC_ERROR_NONE, + [](void* arg, grpc_cq_completion* completion) { delete completion; }, + nullptr, new grpc_cq_completion()); +} + ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag) - : BaseAsyncRequest(server, context, stream, call_cq, tag, true) {} + ServerCompletionQueue* notification_cq, void* tag, const char* name) + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, + true), + name_(name) {} void ServerInterface::RegisteredAsyncRequest::IssueRequest( void* registered_method, grpc_byte_buffer** payload, @@ -751,7 +848,7 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( ServerInterface* server, GenericServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) - : BaseAsyncRequest(server, context, stream, call_cq, tag, + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, delete_on_finalize) { grpc_call_details_init(&call_details_); GPR_ASSERT(notification_cq); @@ -764,6 +861,10 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, bool* status) { + // If we are done intercepting, there is nothing more for us to do + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } // TODO(yangg) remove the copy here. if (*status) { static_cast(context_)->method_ = @@ -774,16 +875,26 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, } grpc_slice_unref(call_details_.method); grpc_slice_unref(call_details_.host); + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info( + static_cast(context_)->method_.c_str(), + *server_->interceptor_creators())); return BaseAsyncRequest::FinalizeResult(tag, status); } bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag, bool* status) { - if (GenericAsyncRequest::FinalizeResult(tag, status) && *status) { - new UnimplementedAsyncRequest(server_, cq_); - new UnimplementedAsyncResponse(this); + if (GenericAsyncRequest::FinalizeResult(tag, status)) { + // We either had no interceptors run or we are done intercepting + if (*status) { + new UnimplementedAsyncRequest(server_, cq_); + new UnimplementedAsyncResponse(this); + } else { + delete this; + } } else { - delete this; + // The tag was swallowed due to interception. We will see it again. } return false; } diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index bd532a968da..995e7877854 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -41,13 +41,22 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { public: // initial refs: one in the server context, one in the cq // must ref the call before calling constructor and after deleting this - CompletionOp(grpc_call* call) - : call_(call), + CompletionOp(internal::Call* call) + : call_(*call), has_tag_(false), tag_(nullptr), refs_(2), finalized_(false), - cancelled_(0) {} + cancelled_(0), + done_intercepting_(false) {} + + ~CompletionOp() { + if (call_.server_rpc_info()) { + call_.server_rpc_info()->Unref(); + } + } + + void FillOps(internal::Call* call) override; // This should always be arena allocated in the call, so override delete. // But this class is not trivially destructible, so must actually call delete @@ -63,7 +72,6 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { // there are no tests catching the compiler warning. static void operator delete(void*, void*) { assert(0); } - void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) override; bool FinalizeResult(void** tag, bool* status) override; bool CheckCancelled(CompletionQueue* cq) { @@ -82,58 +90,121 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { void Unref(); + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override { + /* Servers don't allow hijacking */ + GPR_CODEGEN_ASSERT(false); + } + + /* Should be called after interceptors are done running */ + void ContinueFillOpsAfterInterception() override {} + + /* Should be called after interceptors are done running on the finalize result + * path */ + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + if (!has_tag_) { + /* We don't have a tag to return. */ + std::unique_lock lock(mu_); + if (--refs_ == 0) { + lock.unlock(); + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } + return; + } + /* Start a dummy op so that we can return the tag */ + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), nullptr, 0, this, nullptr)); + } + private: bool CheckCancelledNoPluck() { std::lock_guard g(mu_); return finalized_ ? (cancelled_ != 0) : false; } - grpc_call* call_; + internal::Call call_; bool has_tag_; void* tag_; std::mutex mu_; int refs_; bool finalized_; int cancelled_; + bool done_intercepting_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; }; void ServerContext::CompletionOp::Unref() { std::unique_lock lock(mu_); if (--refs_ == 0) { lock.unlock(); - // Save aside the call pointer before deleting for later unref - grpc_call* call = call_; + grpc_call* call = call_.call(); delete this; grpc_call_unref(call); } } -void ServerContext::CompletionOp::FillOps(grpc_call* call, grpc_op* ops, - size_t* nops) { - ops->op = GRPC_OP_RECV_CLOSE_ON_SERVER; - ops->data.recv_close_on_server.cancelled = &cancelled_; - ops->flags = 0; - ops->reserved = nullptr; - *nops = 1; +void ServerContext::CompletionOp::FillOps(internal::Call* call) { + grpc_op ops; + ops.op = GRPC_OP_RECV_CLOSE_ON_SERVER; + ops.data.recv_close_on_server.cancelled = &cancelled_; + ops.flags = 0; + ops.reserved = nullptr; + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + interceptor_methods_.SetCallOpSetInterface(this); + GPR_ASSERT(GRPC_CALL_OK == + grpc_call_start_batch(call->call(), &ops, 1, this, nullptr)); + /* No interceptors to run here */ } bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { - std::unique_lock lock(mu_); - finalized_ = true; bool ret = false; - if (has_tag_) { - *tag = tag_; - ret = true; + std::unique_lock lock(mu_); + if (done_intercepting_) { + /* We are done intercepting. */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + if (--refs_ == 0) { + lock.unlock(); + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } + return ret; } + finalized_ = true; + if (!*status) cancelled_ = 1; - if (--refs_ == 0) { - lock.unlock(); - // Save aside the call pointer before deleting for later unref - grpc_call* call = call_; - delete this; - grpc_call_unref(call); + /* Release the lock since we are going to be running through interceptors now + */ + lock.unlock(); + /* Add interception point and run through interceptors */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE); + if (interceptor_methods_.RunInterceptors()) { + /* No interceptors were run */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + lock.lock(); + if (--refs_ == 0) { + lock.unlock(); + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } + return ret; } - return ret; + /* There are interceptors to be run. Return false for now */ + return false; } // ServerContext body @@ -169,14 +240,20 @@ ServerContext::~ServerContext() { if (completion_op_) { completion_op_->Unref(); } + if (rpc_info_) { + rpc_info_->Unref(); + } } void ServerContext::BeginCompletionOp(internal::Call* call) { GPR_ASSERT(!completion_op_); + if (rpc_info_) { + rpc_info_->Ref(); + } grpc_call_ref(call->call()); completion_op_ = new (grpc_call_arena_alloc(call->call(), sizeof(CompletionOp))) - CompletionOp(call->call()); + CompletionOp(call); if (has_notify_when_done_tag_) { completion_op_->set_tag(async_notify_when_done_tag_); } diff --git a/test/cpp/end2end/BUILD b/test/cpp/end2end/BUILD index 0415efc1ef9..235249e8bf6 100644 --- a/test/cpp/end2end/BUILD +++ b/test/cpp/end2end/BUILD @@ -35,6 +35,19 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "interceptors_util", + testonly = True, + hdrs = ["interceptors_util.h"], + external_deps = [ + "gtest", + ], + deps = [ + "//src/proto/grpc/testing:echo_proto", + "//test/cpp/util:test_util", + ], +) + grpc_cc_test( name = "async_end2end_test", srcs = ["async_end2end_test.cc"], @@ -117,6 +130,26 @@ grpc_cc_test( ], ) +grpc_cc_test( + name = "client_interceptors_end2end_test", + srcs = ["client_interceptors_end2end_test.cc"], + external_deps = [ + "gtest", + ], + deps = [ + ":interceptors_util", + ":test_service_impl", + "//:gpr", + "//:grpc", + "//:grpc++", + "//src/proto/grpc/testing:echo_messages_proto", + "//src/proto/grpc/testing:echo_proto", + "//test/core/util:gpr_test_util", + "//test/core/util:grpc_test_util", + "//test/cpp/util:test_util", + ], +) + grpc_cc_library( name = "end2end_test_lib", testonly = True, @@ -469,6 +502,26 @@ grpc_cc_binary( ], ) +grpc_cc_test( + name = "server_interceptors_end2end_test", + srcs = ["server_interceptors_end2end_test.cc"], + external_deps = [ + "gtest", + ], + deps = [ + ":interceptors_util", + ":test_service_impl", + "//:gpr", + "//:grpc", + "//:grpc++", + "//src/proto/grpc/testing:echo_messages_proto", + "//src/proto/grpc/testing:echo_proto", + "//test/core/util:gpr_test_util", + "//test/core/util:grpc_test_util", + "//test/cpp/util:test_util", + ], +) + grpc_cc_test( name = "server_load_reporting_end2end_test", srcs = ["server_load_reporting_end2end_test.cc"], diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc new file mode 100644 index 00000000000..5720e87478f --- /dev/null +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -0,0 +1,606 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" + +#include + +namespace grpc { +namespace testing { +namespace { + +class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsStreamingEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsStreamingEnd2endTest() { server_->Shutdown(); } + + std::string server_address_; + EchoTestServiceStreamingImpl service_; + std::unique_ptr server_; +}; + +class ClientInterceptorsEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +/* This interceptor does nothing. Just keeps a global count on the number of + * times it was invoked. */ +class DummyInterceptor : public experimental::Interceptor { + public: + DummyInterceptor(experimental::ClientRpcInfo* info) {} + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + num_times_run_++; + } else if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints:: + POST_RECV_INITIAL_METADATA)) { + num_times_run_reverse_++; + } + methods->Proceed(); + } + + static void Reset() { + num_times_run_.store(0); + num_times_run_reverse_.store(0); + } + + static int GetNumTimesRun() { + EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load()); + return num_times_run_.load(); + } + + private: + static std::atomic num_times_run_; + static std::atomic num_times_run_reverse_; +}; + +std::atomic DummyInterceptor::num_times_run_; +std::atomic DummyInterceptor::num_times_run_reverse_; + +class DummyInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new DummyInterceptor(info); + } +}; + +/* Hijacks Echo RPC and fills in the expected values */ +class HijackingInterceptor : public experimental::Interceptor { + public: + HijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello1"); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message("Hello1"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; +}; + +class HijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptor(info); + } +}; + +class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor { + public: + HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + // Make a copy of the map + metadata_map_ = *map; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); + req_ = req; + stub_ = grpc::testing::EchoTestService::NewStub( + methods->GetInterceptedChannel()); + ctx_.AddMetadata(metadata_map_.begin()->first, + metadata_map_.begin()->second); + stub_->experimental_async()->Echo(&ctx_, &req_, &resp_, + [this, methods](Status s) { + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp_.message(), "Hello"); + methods->Hijack(); + }); + // There isn't going to be any other interesting operation in this batch, + // so it is fine to return + return; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message(resp_.message()); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_; + std::multimap metadata_map_; + ClientContext ctx_; + EchoRequest req_; + EchoResponse resp_; + std::unique_ptr stub_; +}; + +class HijackingInterceptorMakesAnotherCallFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptorMakesAnotherCall(info); + } +}; + +class LoggingInterceptor : public experimental::Interceptor { + public: + LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits::Deserialize(&copied_buffer, &req); + EXPECT_TRUE(req.message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + EXPECT_TRUE(resp->message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_; +}; + +class LoggingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new LoggingInterceptor(info); + } +}; + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back(std::unique_ptr( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + // Add 20 dummy interceptors before hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + creators->push_back(std::unique_ptr( + new HijackingInterceptorFactory())); + // Add 20 dummy interceptors after hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + MakeCall(channel); + // Make sure only 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) { + ChannelArguments args; + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back(std::unique_ptr( + new LoggingInterceptorFactory())); + creators->push_back(std::unique_ptr( + new HijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + MakeCall(channel); +} + +TEST_F(ClientInterceptorsEnd2endTest, + ClientInterceptorHijackingMakesAnotherCallTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + // Add 5 dummy interceptors before hijacking interceptor + for (auto i = 0; i < 5; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + creators->push_back( + std::unique_ptr( + new HijackingInterceptorMakesAnotherCallFactory())); + // Add 7 dummy interceptors after hijacking interceptor + for (auto i = 0; i < 7; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + + MakeCall(channel); + // Make sure all interceptors were run once, since the hijacking interceptor + // makes an RPC on the intercepted channel + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12); +} + +TEST_F(ClientInterceptorsEnd2endTest, + ClientInterceptorLoggingTestWithCallback) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back(std::unique_ptr( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + MakeCallbackCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back(std::unique_ptr( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeClientStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back(std::unique_ptr( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeServerStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back(std::unique_ptr( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeBidiStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc_test_init(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h new file mode 100644 index 00000000000..5f0aa37dc06 --- /dev/null +++ b/test/cpp/end2end/interceptors_util.h @@ -0,0 +1,308 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/string_ref_helper.h" + +#include + +namespace grpc { +namespace testing { +class EchoTestServiceStreamingImpl : public EchoTestService::Service { + public: + ~EchoTestServiceStreamingImpl() override {} + + Status BidiStream( + ServerContext* context, + grpc::ServerReaderWriter* stream) override { + EchoRequest req; + EchoResponse resp; + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + + while (stream->Read(&req)) { + resp.set_message(req.message()); + EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions())); + } + return Status::OK; + } + + Status RequestStream(ServerContext* context, + ServerReader* reader, + EchoResponse* resp) override { + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + + EchoRequest req; + string response_str = ""; + while (reader->Read(&req)) { + response_str += req.message(); + } + resp->set_message(response_str); + return Status::OK; + } + + Status ResponseStream(ServerContext* context, const EchoRequest* req, + ServerWriter* writer) override { + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + + EchoResponse resp; + resp.set_message(req->message()); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(resp)); + } + return Status::OK; + } +}; + +void MakeCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +void MakeClientStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + writer->Write(req); + expected_resp += "Hello"; + } + writer->WritesDone(); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), expected_resp); +} + +void MakeServerStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + string expected_resp = ""; + auto reader = stub->ResponseStream(&ctx, req); + int count = 0; + while (reader->Read(&resp)) { + EXPECT_EQ(resp.message(), "Hello"); + count++; + } + ASSERT_EQ(count, 10); + Status s = reader->Finish(); + EXPECT_EQ(s.ok(), true); +} + +void MakeBidiStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + ctx.AddMetadata("testkey", "testvalue"); + auto stream = stub->BidiStream(&ctx); + for (auto i = 0; i < 10; i++) { + req.set_message("Hello" + std::to_string(i)); + stream->Write(req); + stream->Read(&resp); + EXPECT_EQ(req.message(), resp.message()); + } + ASSERT_TRUE(stream->WritesDone()); + Status s = stream->Finish(); + EXPECT_EQ(s.ok(), true); +} + +void MakeCallbackCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + std::mutex mu; + std::condition_variable cv; + bool done = false; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + stub->experimental_async()->Echo(&ctx, &req, &resp, + [&resp, &mu, &done, &cv](Status s) { + // gpr_log(GPR_ERROR, "got the callback"); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + std::lock_guard l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock l(mu); + while (!done) { + cv.wait(l); + } +} + +bool CheckMetadata(const std::multimap& map, + const string& key, const string& value) { + for (const auto& pair : map) { + if (pair.first.starts_with(key) && pair.second.starts_with(value)) { + return true; + } + } + return false; +} + +void* tag(int i) { return (void*)static_cast(i); } +int detag(void* p) { return static_cast(reinterpret_cast(p)); } + +class Verifier { + public: + Verifier() : lambda_run_(false) {} + // Expect sets the expected ok value for a specific tag + Verifier& Expect(int i, bool expect_ok) { + return ExpectUnless(i, expect_ok, false); + } + // ExpectUnless sets the expected ok value for a specific tag + // unless the tag was already marked seen (as a result of ExpectMaybe) + Verifier& ExpectUnless(int i, bool expect_ok, bool seen) { + if (!seen) { + expectations_[tag(i)] = expect_ok; + } + return *this; + } + // ExpectMaybe sets the expected ok value for a specific tag, but does not + // require it to appear + // If it does, sets *seen to true + Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) { + if (!*seen) { + maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen}; + } + return *this; + } + + // Next waits for 1 async tag to complete, checks its + // expectations, and returns the tag + int Next(CompletionQueue* cq, bool ignore_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + GotTag(got_tag, ok, ignore_ok); + return detag(got_tag); + } + + template + CompletionQueue::NextStatus DoOnceThenAsyncNext( + CompletionQueue* cq, void** got_tag, bool* ok, T deadline, + std::function lambda) { + if (lambda_run_) { + return cq->AsyncNext(got_tag, ok, deadline); + } else { + lambda_run_ = true; + return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline); + } + } + + // Verify keeps calling Next until all currently set + // expected tags are complete + void Verify(CompletionQueue* cq) { Verify(cq, false); } + + // This version of Verify allows optionally ignoring the + // outcome of the expectation + void Verify(CompletionQueue* cq, bool ignore_ok) { + GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty()); + while (!expectations_.empty()) { + Next(cq, ignore_ok); + } + } + + // This version of Verify stops after a certain deadline, and uses the + // DoThenAsyncNext API + // to call the lambda + void Verify(CompletionQueue* cq, + std::chrono::system_clock::time_point deadline, + const std::function& lambda) { + if (expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::TIMEOUT); + } else { + while (!expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::GOT_EVENT); + GotTag(got_tag, ok, false); + } + } + } + + private: + void GotTag(void* got_tag, bool ok, bool ignore_ok) { + auto it = expectations_.find(got_tag); + if (it != expectations_.end()) { + if (!ignore_ok) { + EXPECT_EQ(it->second, ok); + } + expectations_.erase(it); + } else { + auto it2 = maybe_expectations_.find(got_tag); + if (it2 != maybe_expectations_.end()) { + if (it2->second.seen != nullptr) { + EXPECT_FALSE(*it2->second.seen); + *it2->second.seen = true; + } + if (!ignore_ok) { + EXPECT_EQ(it2->second.ok, ok); + } + } else { + gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag); + abort(); + } + } + } + + struct MaybeExpect { + bool ok; + bool* seen; + }; + + std::map expectations_; + std::map maybe_expectations_; + bool lambda_run_; +}; + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc new file mode 100644 index 00000000000..44ba2a60093 --- /dev/null +++ b/test/cpp/end2end/server_interceptors_end2end_test.cc @@ -0,0 +1,623 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +#include + +namespace grpc { +namespace testing { +namespace { + +/* This interceptor does nothing. Just keeps a global count on the number of + * times it was invoked. */ +class DummyInterceptor : public experimental::Interceptor { + public: + DummyInterceptor(experimental::ServerRpcInfo* info) {} + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + num_times_run_++; + } else if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints:: + POST_RECV_INITIAL_METADATA)) { + num_times_run_reverse_++; + } + methods->Proceed(); + } + + static void Reset() { + num_times_run_.store(0); + num_times_run_reverse_.store(0); + } + + static int GetNumTimesRun() { + EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load()); + return num_times_run_.load(); + } + + private: + static std::atomic num_times_run_; + static std::atomic num_times_run_reverse_; +}; + +std::atomic DummyInterceptor::num_times_run_; +std::atomic DummyInterceptor::num_times_run_reverse_; + +class DummyInterceptorFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new DummyInterceptor(info); + } +}; + +class LoggingInterceptor : public experimental::Interceptor { + public: + LoggingInterceptor(experimental::ServerRpcInfo* info) { info_ = info; } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits::Deserialize(&copied_buffer, &req); + EXPECT_TRUE(req.message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_STATUS)) { + auto* map = methods->GetSendTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.find("testkey") == 0 && + pair.second.find("testvalue") == 0; + if (found) break; + } + EXPECT_EQ(found, true); + auto status = methods->GetSendStatus(); + EXPECT_EQ(status.ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.find("testkey") == 0 && + pair.second.find("testvalue") == 0; + if (found) break; + } + EXPECT_EQ(found, true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + EXPECT_TRUE(resp->message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE)) { + // Got nothing interesting to do here + } + methods->Proceed(); + } + + private: + experimental::ServerRpcInfo* info_; +}; + +class LoggingInterceptorFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new LoggingInterceptor(info); + } +}; + +void MakeBidiStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + ctx.AddMetadata("testkey", "testvalue"); + auto stream = stub->BidiStream(&ctx); + for (auto i = 0; i < 10; i++) { + req.set_message("Hello" + std::to_string(i)); + stream->Write(req); + stream->Read(&resp); + EXPECT_EQ(req.message(), resp.message()); + } + ASSERT_TRUE(stream->WritesDone()); + Status s = stream->Finish(); + EXPECT_EQ(s.ok(), true); +} + +class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test { + protected: + ServerInterceptorsEnd2endSyncUnaryTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + + std::vector< + std::unique_ptr> + creators; + creators.push_back( + std::unique_ptr( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + server_ = builder.BuildAndStart(); + } + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +TEST_F(ServerInterceptorsEnd2endSyncUnaryTest, UnaryTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + MakeCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test { + protected: + ServerInterceptorsEnd2endSyncStreamingTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + + std::vector< + std::unique_ptr> + creators; + creators.push_back( + std::unique_ptr( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + server_ = builder.BuildAndStart(); + } + std::string server_address_; + EchoTestServiceStreamingImpl service_; + std::unique_ptr server_; +}; + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ClientStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + MakeClientStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ServerStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + MakeServerStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, BidiStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + MakeBidiStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +class ServerInterceptorsAsyncEnd2endTest : public ::testing::Test {}; + +TEST_F(ServerInterceptorsAsyncEnd2endTest, UnaryTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector> + creators; + creators.push_back( + std::unique_ptr( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter response_writer(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, cq.get())); + + service.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq.get(), + cq.get(), tag(2)); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, BidiStreamingTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector> + creators; + creators.push_back( + std::unique_ptr( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncReaderWriter srv_stream(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr> + cli_stream(stub->AsyncBidiStream(&cli_ctx, cq.get(), tag(1))); + + service.RequestBidiStream(&srv_ctx, &srv_stream, cq.get(), cq.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq.get()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + cli_stream->Write(send_request, tag(3)); + srv_stream.Read(&recv_request, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq.get()); + + srv_stream.Finish(Status::OK, tag(9)); + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq.get()); + + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, GenericRPCTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + AsyncGenericService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterAsyncGenericService(&service); + std::vector> + creators; + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = CreateChannel(server_address, InsecureChannelCredentials()); + GenericStub generic_stub(channel); + + const grpc::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + + std::unique_ptr call = + generic_stub.PrepareCall(&cli_ctx, kMethodName, cq.get()); + call->StartCall(tag(1)); + Verifier().Expect(1, true).Verify(cq.get()); + std::unique_ptr send_buffer = + SerializeToByteBuffer(&send_request); + call->Write(*send_buffer, tag(2)); + // Send ByteBuffer can be destroyed after calling Write. + send_buffer.reset(); + Verifier().Expect(2, true).Verify(cq.get()); + call->WritesDone(tag(3)); + Verifier().Expect(3, true).Verify(cq.get()); + + service.RequestCall(&srv_ctx, &stream, cq.get(), cq.get(), tag(4)); + + Verifier().Expect(4, true).Verify(cq.get()); + EXPECT_EQ(kMethodName, srv_ctx.method()); + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + ByteBuffer recv_buffer; + stream.Read(&recv_buffer, tag(5)); + Verifier().Expect(5, true).Verify(cq.get()); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + stream.Write(*send_buffer, tag(6)); + send_buffer.reset(); + Verifier().Expect(6, true).Verify(cq.get()); + + stream.Finish(Status::OK, tag(7)); + Verifier().Expect(7, true).Verify(cq.get()); + + recv_buffer.Clear(); + call->Read(&recv_buffer, tag(8)); + Verifier().Expect(8, true).Verify(cq.get()); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + + call->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, UnimplementedRpcTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + std::vector> + creators; + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + std::shared_ptr channel = + CreateChannel(server_address, InsecureChannelCredentials()); + std::unique_ptr stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + send_request.set_message("Hello"); + std::unique_ptr> response_reader( + stub->AsyncUnimplemented(&cli_ctx, send_request, cq.get())); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + Verifier().Expect(4, true).Verify(cq.get()); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + +class ServerInterceptorsSyncUnimplementedEnd2endTest : public ::testing::Test { +}; + +TEST_F(ServerInterceptorsSyncUnimplementedEnd2endTest, UnimplementedRpcTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + TestServiceImpl service; + builder.RegisterService(&service); + builder.AddListeningPort(server_address, InsecureServerCredentials()); + std::vector> + creators; + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + std::shared_ptr channel = + CreateChannel(server_address, InsecureChannelCredentials()); + std::unique_ptr stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + + ClientContext cli_ctx; + send_request.set_message("Hello"); + Status recv_status = + stub->Unimplemented(&cli_ctx, send_request, &recv_response); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + grpc_recycle_unused_port(port); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc_test_init(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/interop/client_helper.h b/test/cpp/interop/client_helper.h index eada2f671f3..7dee85cc980 100644 --- a/test/cpp/interop/client_helper.h +++ b/test/cpp/interop/client_helper.h @@ -19,10 +19,12 @@ #ifndef GRPC_TEST_CPP_INTEROP_CLIENT_HELPER_H #define GRPC_TEST_CPP_INTEROP_CLIENT_HELPER_H +#include #include #include #include +#include #include "src/core/lib/surface/call_test_only.h" diff --git a/tools/doxygen/Doxyfile.c++ b/tools/doxygen/Doxyfile.c++ index 1abf6bddce1..f2bb5df7d33 100644 --- a/tools/doxygen/Doxyfile.c++ +++ b/tools/doxygen/Doxyfile.c++ @@ -943,6 +943,8 @@ include/grpcpp/impl/codegen/async_unary_call.h \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ +include/grpcpp/impl/codegen/call_op_set.h \ +include/grpcpp/impl/codegen/call_op_set_interface.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -957,7 +959,9 @@ include/grpcpp/impl/codegen/core_codegen.h \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ +include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ +include/grpcpp/impl/codegen/interceptor_common.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ include/grpcpp/impl/codegen/proto_buffer_reader.h \ @@ -968,6 +972,7 @@ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ +include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ diff --git a/tools/doxygen/Doxyfile.c++.internal b/tools/doxygen/Doxyfile.c++.internal index 26e671668ea..81d74cd9724 100644 --- a/tools/doxygen/Doxyfile.c++.internal +++ b/tools/doxygen/Doxyfile.c++.internal @@ -944,6 +944,8 @@ include/grpcpp/impl/codegen/async_unary_call.h \ include/grpcpp/impl/codegen/byte_buffer.h \ include/grpcpp/impl/codegen/call.h \ include/grpcpp/impl/codegen/call_hook.h \ +include/grpcpp/impl/codegen/call_op_set.h \ +include/grpcpp/impl/codegen/call_op_set_interface.h \ include/grpcpp/impl/codegen/callback_common.h \ include/grpcpp/impl/codegen/channel_interface.h \ include/grpcpp/impl/codegen/client_callback.h \ @@ -959,7 +961,9 @@ include/grpcpp/impl/codegen/core_codegen.h \ include/grpcpp/impl/codegen/core_codegen_interface.h \ include/grpcpp/impl/codegen/create_auth_context.h \ include/grpcpp/impl/codegen/grpc_library.h \ +include/grpcpp/impl/codegen/intercepted_channel.h \ include/grpcpp/impl/codegen/interceptor.h \ +include/grpcpp/impl/codegen/interceptor_common.h \ include/grpcpp/impl/codegen/metadata_map.h \ include/grpcpp/impl/codegen/method_handler_impl.h \ include/grpcpp/impl/codegen/proto_buffer_reader.h \ @@ -970,6 +974,7 @@ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ include/grpcpp/impl/codegen/server_context.h \ +include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ include/grpcpp/impl/codegen/service_type.h \ include/grpcpp/impl/codegen/slice.h \ diff --git a/tools/run_tests/generated/sources_and_headers.json b/tools/run_tests/generated/sources_and_headers.json index 3a442864e8f..9d7d487e0df 100644 --- a/tools/run_tests/generated/sources_and_headers.json +++ b/tools/run_tests/generated/sources_and_headers.json @@ -3385,6 +3385,28 @@ "third_party": false, "type": "target" }, + { + "deps": [ + "gpr", + "gpr_test_util", + "grpc", + "grpc++", + "grpc++_test_util", + "grpc_test_util" + ], + "headers": [ + "test/cpp/end2end/interceptors_util.h" + ], + "is_filegroup": false, + "language": "c++", + "name": "client_interceptors_end2end_test", + "src": [ + "test/cpp/end2end/client_interceptors_end2end_test.cc", + "test/cpp/end2end/interceptors_util.h" + ], + "third_party": false, + "type": "target" + }, { "deps": [ "gpr", @@ -11154,6 +11176,8 @@ "include/grpcpp/impl/codegen/byte_buffer.h", "include/grpcpp/impl/codegen/call.h", "include/grpcpp/impl/codegen/call_hook.h", + "include/grpcpp/impl/codegen/call_op_set.h", + "include/grpcpp/impl/codegen/call_op_set_interface.h", "include/grpcpp/impl/codegen/callback_common.h", "include/grpcpp/impl/codegen/channel_interface.h", "include/grpcpp/impl/codegen/client_callback.h", @@ -11166,7 +11190,9 @@ "include/grpcpp/impl/codegen/core_codegen_interface.h", "include/grpcpp/impl/codegen/create_auth_context.h", "include/grpcpp/impl/codegen/grpc_library.h", + "include/grpcpp/impl/codegen/intercepted_channel.h", "include/grpcpp/impl/codegen/interceptor.h", + "include/grpcpp/impl/codegen/interceptor_common.h", "include/grpcpp/impl/codegen/metadata_map.h", "include/grpcpp/impl/codegen/method_handler_impl.h", "include/grpcpp/impl/codegen/rpc_method.h", @@ -11174,6 +11200,7 @@ "include/grpcpp/impl/codegen/security/auth_context.h", "include/grpcpp/impl/codegen/serialization_traits.h", "include/grpcpp/impl/codegen/server_context.h", + "include/grpcpp/impl/codegen/server_interceptor.h", "include/grpcpp/impl/codegen/server_interface.h", "include/grpcpp/impl/codegen/service_type.h", "include/grpcpp/impl/codegen/slice.h", @@ -11224,6 +11251,8 @@ "include/grpcpp/impl/codegen/byte_buffer.h", "include/grpcpp/impl/codegen/call.h", "include/grpcpp/impl/codegen/call_hook.h", + "include/grpcpp/impl/codegen/call_op_set.h", + "include/grpcpp/impl/codegen/call_op_set_interface.h", "include/grpcpp/impl/codegen/callback_common.h", "include/grpcpp/impl/codegen/channel_interface.h", "include/grpcpp/impl/codegen/client_callback.h", @@ -11236,7 +11265,9 @@ "include/grpcpp/impl/codegen/core_codegen_interface.h", "include/grpcpp/impl/codegen/create_auth_context.h", "include/grpcpp/impl/codegen/grpc_library.h", + "include/grpcpp/impl/codegen/intercepted_channel.h", "include/grpcpp/impl/codegen/interceptor.h", + "include/grpcpp/impl/codegen/interceptor_common.h", "include/grpcpp/impl/codegen/metadata_map.h", "include/grpcpp/impl/codegen/method_handler_impl.h", "include/grpcpp/impl/codegen/rpc_method.h", @@ -11244,6 +11275,7 @@ "include/grpcpp/impl/codegen/security/auth_context.h", "include/grpcpp/impl/codegen/serialization_traits.h", "include/grpcpp/impl/codegen/server_context.h", + "include/grpcpp/impl/codegen/server_interceptor.h", "include/grpcpp/impl/codegen/server_interface.h", "include/grpcpp/impl/codegen/service_type.h", "include/grpcpp/impl/codegen/slice.h", diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 145d2526e8d..10526666341 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -4027,6 +4027,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 0.5, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "client_interceptors_end2end_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": true + }, { "args": [], "benchmark": false,