diff --git a/BUILDING.md b/BUILDING.md index 1ed79bbcaca..9addae90fc7 100644 --- a/BUILDING.md +++ b/BUILDING.md @@ -161,10 +161,11 @@ Please note that when using Ninja, you will still need Visual C++ (part of Visua installed to be able to compile the C/C++ sources. ``` > @rem Run from grpc directory after cloning the repo with --recursive or updating submodules. -> md .build -> cd .build +> cd cmake +> md build +> cd build > call "%VS140COMNTOOLS%..\..\VC\vcvarsall.bat" x64 -> cmake .. -GNinja -DCMAKE_BUILD_TYPE=Release +> cmake ..\.. -GNinja -DCMAKE_BUILD_TYPE=Release > cmake --build . ``` @@ -183,7 +184,7 @@ ie `gRPC_CARES_PROVIDER`. ### Install after build Perform the following steps to install gRPC using CMake. -* Set `gRPC_INSTALL` to `ON` +* Set `-DgRPC_INSTALL=ON` * Build the `install` target The install destination is controlled by the @@ -196,16 +197,21 @@ in "module" mode and install them alongside gRPC in a single step. If you are using an older version of gRPC, you will need to select "package" mode (rather than "module" mode) for the dependencies. This means you will need to have external copies of these libraries available -on your system. +on your system. This [example](test/distrib/cpp/run_distrib_test_cmake.sh) shows +how to install dependencies with cmake before proceeding to installing gRPC itself. + ``` -$ cmake .. -DgRPC_CARES_PROVIDER=package \ - -DgRPC_PROTOBUF_PROVIDER=package \ - -DgRPC_SSL_PROVIDER=package \ - -DgRPC_ZLIB_PROVIDER=package +# NOTE: all of gRPC's dependencies need to be already installed +$ cmake ../.. -DgRPC_INSTALL=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DgRPC_ABSL_PROVIDER=package \ + -DgRPC_CARES_PROVIDER=package \ + -DgRPC_PROTOBUF_PROVIDER=package \ + -DgRPC_SSL_PROVIDER=package \ + -DgRPC_ZLIB_PROVIDER=package $ make $ make install ``` -[Example](test/distrib/cpp/run_distrib_test_cmake.sh) ### Cross-compiling @@ -222,7 +228,7 @@ that will be used for this build. This toolchain file is specified to CMake by setting the `CMAKE_TOOLCHAIN_FILE` variable. ``` -$ cmake .. -DCMAKE_TOOLCHAIN_FILE=path/to/file +$ cmake ../.. -DCMAKE_TOOLCHAIN_FILE=path/to/file $ make ``` diff --git a/bazel/cc_grpc_library.bzl b/bazel/cc_grpc_library.bzl index dea493eaf20..7ec1a98ec51 100644 --- a/bazel/cc_grpc_library.bzl +++ b/bazel/cc_grpc_library.bzl @@ -1,5 +1,6 @@ """Generates and compiles C++ grpc stubs from proto_library rules.""" +load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:generate_cc.bzl", "generate_cc") load("//bazel:protobuf.bzl", "well_known_proto_libs") @@ -63,8 +64,7 @@ def cc_grpc_library( proto_deps += [dep.split(":")[0] + ":" + "_" + dep.split(":")[1] + "_only" for dep in deps if dep.find(":") != -1] if well_known_protos: proto_deps += well_known_proto_libs() - - native.proto_library( + proto_library( name = proto_target, srcs = srcs, deps = proto_deps, diff --git a/bazel/generate_cc.bzl b/bazel/generate_cc.bzl index 484959ebb70..a1808630217 100644 --- a/bazel/generate_cc.bzl +++ b/bazel/generate_cc.bzl @@ -4,6 +4,7 @@ This is an internal rule used by cc_grpc_library, and shouldn't be used directly. """ +load("@rules_proto//proto:defs.bzl", "ProtoInfo") load( "//bazel:protobuf.bzl", "get_include_directory", diff --git a/bazel/generate_objc.bzl b/bazel/generate_objc.bzl index 3bf5aa39243..cffe4043992 100644 --- a/bazel/generate_objc.bzl +++ b/bazel/generate_objc.bzl @@ -1,3 +1,4 @@ +load("@rules_proto//proto:defs.bzl", "ProtoInfo") load( "//bazel:protobuf.bzl", "get_include_directory", diff --git a/bazel/protobuf.bzl b/bazel/protobuf.bzl index 7af27a8b308..330301e930c 100644 --- a/bazel/protobuf.bzl +++ b/bazel/protobuf.bzl @@ -1,5 +1,7 @@ """Utility functions for generating protobuf code.""" +load("@rules_proto//proto:defs.bzl", "ProtoInfo") + _PROTO_EXTENSION = ".proto" _VIRTUAL_IMPORTS = "/_virtual_imports/" diff --git a/bazel/python_rules.bzl b/bazel/python_rules.bzl index e8ada92f15f..39fee5d4d34 100644 --- a/bazel/python_rules.bzl +++ b/bazel/python_rules.bzl @@ -1,5 +1,6 @@ """Generates and compiles Python gRPC stubs from proto_library rules.""" +load("@rules_proto//proto:defs.bzl", "ProtoInfo") load( "//bazel:protobuf.bzl", "declare_out_files", diff --git a/bazel/test/python_test_repo/BUILD b/bazel/test/python_test_repo/BUILD index da6d135d8e7..fb23cf3370f 100644 --- a/bazel/test/python_test_repo/BUILD +++ b/bazel/test/python_test_repo/BUILD @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_proto//proto:defs.bzl", "proto_library") load( "@com_github_grpc_grpc//bazel:python_rules.bzl", "py2and3_test", diff --git a/doc/environment_variables.md b/doc/environment_variables.md index ff6eaba57b7..aef14a9842c 100644 --- a/doc/environment_variables.md +++ b/doc/environment_variables.md @@ -4,8 +4,14 @@ gRPC environment variables gRPC C core based implementations (those contained in this repository) expose some configuration as environment variables that can be set. -* http_proxy - The URI of the proxy to use for HTTP CONNECT support. +* grpc_proxy, https_proxy, http_proxy + The URI of the proxy to use for HTTP CONNECT support. These variables are + checked in order, and the first one that has a value is used. + +* no_grpc_proxy, no_proxy + A comma separated list of hostnames to connect to without using a proxy even + if a proxy is set. These variables are checked in order, and the first one + that has a value is used. * GRPC_ABORT_ON_LEAKS A debugging aid to cause a call to abort() when gRPC objects are leaked past diff --git a/examples/BUILD b/examples/BUILD index f6dae42ca4c..b6458b74b83 100644 --- a/examples/BUILD +++ b/examples/BUILD @@ -16,10 +16,11 @@ licenses(["notice"]) # 3-clause BSD package(default_visibility = ["//visibility:public"]) -load("//bazel:grpc_build_system.bzl", "grpc_proto_library") +load("@grpc_python_dependencies//:requirements.bzl", "requirement") +load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load("//bazel:grpc_build_system.bzl", "grpc_proto_library") load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") -load("@grpc_python_dependencies//:requirements.bzl", "requirement") grpc_proto_library( name = "auth_sample", diff --git a/examples/cpp/helloworld/README.md b/examples/cpp/helloworld/README.md index 813a80f288f..584c46ec40d 100644 --- a/examples/cpp/helloworld/README.md +++ b/examples/cpp/helloworld/README.md @@ -255,6 +255,10 @@ main loop in `HandleRpcs` to query the queue. For a working example, refer to [greeter_async_server.cc](greeter_async_server.cc). +#### Flags for the client +```sh +./greeter_client --target="a target string used to create a GRPC client channel" +``` - +The Default value for --target is "localhost:50051". diff --git a/examples/cpp/helloworld/greeter_client.cc b/examples/cpp/helloworld/greeter_client.cc index 932583c84ab..7ece0330c51 100644 --- a/examples/cpp/helloworld/greeter_client.cc +++ b/examples/cpp/helloworld/greeter_client.cc @@ -73,11 +73,32 @@ class GreeterClient { int main(int argc, char** argv) { // Instantiate the client. It requires a channel, out of which the actual RPCs - // are created. This channel models a connection to an endpoint (in this case, - // localhost at port 50051). We indicate that the channel isn't authenticated - // (use of InsecureChannelCredentials()). + // are created. This channel models a connection to an endpoint specified by + // the argument "--target=" which is the only expected argument. + // We indicate that the channel isn't authenticated (use of + // InsecureChannelCredentials()). + std::string target_str; + std::string arg_str("--target"); + if (argc > 1) { + std::string arg_val = argv[1]; + size_t start_pos = arg_val.find(arg_str); + if (start_pos != std::string::npos) { + start_pos += arg_str.size(); + if (arg_val[start_pos] == '=') { + target_str = arg_val.substr(start_pos + 1); + } else { + std::cout << "The only correct argument syntax is --target=" << std::endl; + return 0; + } + } else { + std::cout << "The only acceptable argument is --target=" << std::endl; + return 0; + } + } else { + target_str = "localhost:50051"; + } GreeterClient greeter(grpc::CreateChannel( - "localhost:50051", grpc::InsecureChannelCredentials())); + target_str, grpc::InsecureChannelCredentials())); std::string user("world"); std::string reply = greeter.SayHello(user); std::cout << "Greeter received: " << reply << std::endl; diff --git a/examples/python/cancellation/BUILD.bazel b/examples/python/cancellation/BUILD.bazel index 5215991c5ff..0426cf0b943 100644 --- a/examples/python/cancellation/BUILD.bazel +++ b/examples/python/cancellation/BUILD.bazel @@ -15,6 +15,7 @@ # limitations under the License. load("@grpc_python_dependencies//:requirements.bzl", "requirement") +load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") package(default_testonly = 1) diff --git a/examples/python/multiprocessing/BUILD b/examples/python/multiprocessing/BUILD index f51e235caaa..ea9b6a3ec6f 100644 --- a/examples/python/multiprocessing/BUILD +++ b/examples/python/multiprocessing/BUILD @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") proto_library( diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc b/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc index 4b39d88bc9d..5344cb3ddb1 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/cds.cc @@ -63,7 +63,7 @@ class CdsLb : public LoadBalancingPolicy { public: explicit ClusterWatcher(RefCountedPtr parent) : parent_(std::move(parent)) {} - void OnClusterChanged(CdsUpdate cluster_data) override; + void OnClusterChanged(XdsApi::CdsUpdate cluster_data) override; void OnError(grpc_error* error) override; private: @@ -111,7 +111,7 @@ class CdsLb : public LoadBalancingPolicy { // CdsLb::ClusterWatcher // -void CdsLb::ClusterWatcher::OnClusterChanged(CdsUpdate cluster_data) { +void CdsLb::ClusterWatcher::OnClusterChanged(XdsApi::CdsUpdate cluster_data) { if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { gpr_log(GPR_INFO, "[cdslb %p] received CDS update from xds client", parent_.get()); diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc index 1c8c46eaace..5cd24724f6e 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc @@ -172,7 +172,7 @@ class XdsLb : public LoadBalancingPolicy { RefCountedPtr xds_policy_; PickerList pickers_; - RefCountedPtr drop_config_; + RefCountedPtr drop_config_; }; class FallbackHelper : public ChannelControlHelper { @@ -286,7 +286,7 @@ class XdsLb : public LoadBalancingPolicy { ~LocalityMap() { xds_policy_.reset(DEBUG_LOCATION, "LocalityMap"); } void UpdateLocked( - const XdsPriorityListUpdate::LocalityMap& locality_map_update); + const XdsApi::PriorityListUpdate::LocalityMap& locality_map_update); void ResetBackoffLocked(); void UpdateXdsPickerLocked(); OrphanablePtr ExtractLocalityLocked( @@ -316,10 +316,10 @@ class XdsLb : public LoadBalancingPolicy { static void OnDelayedRemovalTimerLocked(void* arg, grpc_error* error); static void OnFailoverTimerLocked(void* arg, grpc_error* error); - const XdsPriorityListUpdate& priority_list_update() const { + const XdsApi::PriorityListUpdate& priority_list_update() const { return xds_policy_->priority_list_update_; } - const XdsPriorityListUpdate::LocalityMap* locality_map_update() const { + const XdsApi::PriorityListUpdate::LocalityMap* locality_map_update() const { return xds_policy_->priority_list_update_.Find(priority_); } @@ -431,10 +431,10 @@ class XdsLb : public LoadBalancingPolicy { // The priority that is being used. uint32_t current_priority_ = UINT32_MAX; // The update for priority_list_. - XdsPriorityListUpdate priority_list_update_; + XdsApi::PriorityListUpdate priority_list_update_; // The config for dropping calls. - RefCountedPtr drop_config_; + RefCountedPtr drop_config_; // The stats for client-side load reporting. XdsClientStats client_stats_; @@ -594,7 +594,7 @@ class XdsLb::EndpointWatcher : public XdsClient::EndpointWatcherInterface { ~EndpointWatcher() { xds_policy_.reset(DEBUG_LOCATION, "EndpointWatcher"); } - void OnEndpointChanged(EdsUpdate update) override { + void OnEndpointChanged(XdsApi::EdsUpdate update) override { if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_trace)) { gpr_log(GPR_INFO, "[xdslb %p] Received EDS update from xds client", xds_policy_.get()); @@ -1032,6 +1032,8 @@ void XdsLb::UpdatePrioritiesLocked() { for (uint32_t priority = 0; priority < priorities_.size(); ++priority) { LocalityMap* locality_map = priorities_[priority].get(); const auto* locality_map_update = priority_list_update_.Find(priority); + // If we have more current priorities than exist in the update, stop here. + if (locality_map_update == nullptr) break; // Propagate locality_map_update. // TODO(juanlishen): Find a clean way to skip duplicate update for a // priority. @@ -1154,7 +1156,7 @@ XdsLb::LocalityMap::LocalityMap(RefCountedPtr xds_policy, } void XdsLb::LocalityMap::UpdateLocked( - const XdsPriorityListUpdate::LocalityMap& locality_map_update) { + const XdsApi::PriorityListUpdate::LocalityMap& locality_map_update) { if (xds_policy_->shutting_down_) return; if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_trace)) { gpr_log(GPR_INFO, "[xdslb %p] Start Updating priority %" PRIu32, diff --git a/src/core/ext/filters/client_channel/xds/xds_api.cc b/src/core/ext/filters/client_channel/xds/xds_api.cc index a51256594fb..d41c25e2d26 100644 --- a/src/core/ext/filters/client_channel/xds/xds_api.cc +++ b/src/core/ext/filters/client_channel/xds/xds_api.cc @@ -56,8 +56,12 @@ namespace grpc_core { -bool XdsPriorityListUpdate::operator==( - const XdsPriorityListUpdate& other) const { +// +// XdsApi::PriorityListUpdate +// + +bool XdsApi::PriorityListUpdate::operator==( + const XdsApi::PriorityListUpdate& other) const { if (priorities_.size() != other.priorities_.size()) return false; for (size_t i = 0; i < priorities_.size(); ++i) { if (priorities_[i].localities != other.priorities_[i].localities) { @@ -67,8 +71,8 @@ bool XdsPriorityListUpdate::operator==( return true; } -void XdsPriorityListUpdate::Add( - XdsPriorityListUpdate::LocalityMap::Locality locality) { +void XdsApi::PriorityListUpdate::Add( + XdsApi::PriorityListUpdate::LocalityMap::Locality locality) { // Pad the missing priorities in case the localities are not ordered by // priority. if (!Contains(locality.priority)) priorities_.resize(locality.priority + 1); @@ -76,13 +80,13 @@ void XdsPriorityListUpdate::Add( locality_map.localities.emplace(locality.name, std::move(locality)); } -const XdsPriorityListUpdate::LocalityMap* XdsPriorityListUpdate::Find( +const XdsApi::PriorityListUpdate::LocalityMap* XdsApi::PriorityListUpdate::Find( uint32_t priority) const { if (!Contains(priority)) return nullptr; return &priorities_[priority]; } -bool XdsPriorityListUpdate::Contains( +bool XdsApi::PriorityListUpdate::Contains( const RefCountedPtr& name) { for (size_t i = 0; i < priorities_.size(); ++i) { const LocalityMap& locality_map = priorities_[i]; @@ -91,7 +95,11 @@ bool XdsPriorityListUpdate::Contains( return false; } -bool XdsDropConfig::ShouldDrop(const std::string** category_name) const { +// +// XdsApi::DropConfig +// + +bool XdsApi::DropConfig::ShouldDrop(const std::string** category_name) const { for (size_t i = 0; i < drop_category_list_.size(); ++i) { const auto& drop_category = drop_category_list_[i]; // Generate a random number in [0, 1000000). @@ -104,6 +112,17 @@ bool XdsDropConfig::ShouldDrop(const std::string** category_name) const { return false; } +// +// XdsApi +// + +const char* XdsApi::kLdsTypeUrl = "type.googleapis.com/envoy.api.v2.Listener"; +const char* XdsApi::kRdsTypeUrl = + "type.googleapis.com/envoy.api.v2.RouteConfiguration"; +const char* XdsApi::kCdsTypeUrl = "type.googleapis.com/envoy.api.v2.Cluster"; +const char* XdsApi::kEdsTypeUrl = + "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment"; + namespace { void PopulateMetadataValue(upb_arena* arena, google_protobuf_Value* value_pb, @@ -203,67 +222,21 @@ void PopulateNode(upb_arena* arena, const XdsBootstrap::Node* node, upb_strview_makez(build_version)); } -} // namespace - -grpc_slice XdsUnsupportedTypeNackRequestCreateAndEncode( - const std::string& type_url, const std::string& nonce, grpc_error* error) { - upb::Arena arena; +envoy_api_v2_DiscoveryRequest* CreateDiscoveryRequest( + upb_arena* arena, const char* type_url, const std::string& version, + const std::string& nonce, grpc_error* error, const XdsBootstrap::Node* node, + const char* build_version) { // Create a request. envoy_api_v2_DiscoveryRequest* request = - envoy_api_v2_DiscoveryRequest_new(arena.ptr()); + envoy_api_v2_DiscoveryRequest_new(arena); // Set type_url. - envoy_api_v2_DiscoveryRequest_set_type_url( - request, upb_strview_makez(type_url.c_str())); - // Set nonce. - envoy_api_v2_DiscoveryRequest_set_response_nonce( - request, upb_strview_makez(nonce.c_str())); - // Set error_detail. - grpc_slice error_description_slice; - GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, - &error_description_slice)); - upb_strview error_description_strview = - upb_strview_make(reinterpret_cast( - GPR_SLICE_START_PTR(error_description_slice)), - GPR_SLICE_LENGTH(error_description_slice)); - google_rpc_Status* error_detail = - envoy_api_v2_DiscoveryRequest_mutable_error_detail(request, arena.ptr()); - google_rpc_Status_set_message(error_detail, error_description_strview); - GRPC_ERROR_UNREF(error); - // Encode the request. - size_t output_length; - char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(), - &output_length); - return grpc_slice_from_copied_buffer(output, output_length); -} - -grpc_slice XdsLdsRequestCreateAndEncode(const std::string& server_name, - const XdsBootstrap::Node* node, - const char* build_version, - const std::string& version, - const std::string& nonce, - grpc_error* error) { - upb::Arena arena; - // Create a request. - envoy_api_v2_DiscoveryRequest* request = - envoy_api_v2_DiscoveryRequest_new(arena.ptr()); + envoy_api_v2_DiscoveryRequest_set_type_url(request, + upb_strview_makez(type_url)); // Set version_info. if (!version.empty()) { envoy_api_v2_DiscoveryRequest_set_version_info( request, upb_strview_makez(version.c_str())); } - // Populate node. - if (build_version != nullptr) { - envoy_api_v2_core_Node* node_msg = - envoy_api_v2_DiscoveryRequest_mutable_node(request, arena.ptr()); - PopulateNode(arena.ptr(), node, build_version, node_msg); - } - // Add resource_name. - envoy_api_v2_DiscoveryRequest_add_resource_names( - request, upb_strview_make(server_name.data(), server_name.size()), - arena.ptr()); - // Set type_url. - envoy_api_v2_DiscoveryRequest_set_type_url(request, - upb_strview_makez(kLdsTypeUrl)); // Set nonce. if (!nonce.empty()) { envoy_api_v2_DiscoveryRequest_set_response_nonce( @@ -279,148 +252,98 @@ grpc_slice XdsLdsRequestCreateAndEncode(const std::string& server_name, GPR_SLICE_START_PTR(error_description_slice)), GPR_SLICE_LENGTH(error_description_slice)); google_rpc_Status* error_detail = - envoy_api_v2_DiscoveryRequest_mutable_error_detail(request, - arena.ptr()); + envoy_api_v2_DiscoveryRequest_mutable_error_detail(request, arena); google_rpc_Status_set_message(error_detail, error_description_strview); GRPC_ERROR_UNREF(error); } - // Encode the request. + // Populate node. + if (build_version != nullptr) { + envoy_api_v2_core_Node* node_msg = + envoy_api_v2_DiscoveryRequest_mutable_node(request, arena); + PopulateNode(arena, node, build_version, node_msg); + } + return request; +} + +grpc_slice SerializeDiscoveryRequest(upb_arena* arena, + envoy_api_v2_DiscoveryRequest* request) { size_t output_length; - char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(), - &output_length); + char* output = + envoy_api_v2_DiscoveryRequest_serialize(request, arena, &output_length); return grpc_slice_from_copied_buffer(output, output_length); } -grpc_slice XdsRdsRequestCreateAndEncode(const std::string& route_config_name, - const XdsBootstrap::Node* node, - const char* build_version, - const std::string& version, - const std::string& nonce, - grpc_error* error) { +} // namespace + +grpc_slice XdsApi::CreateUnsupportedTypeNackRequest(const std::string& type_url, + const std::string& nonce, + grpc_error* error) { + upb::Arena arena; + envoy_api_v2_DiscoveryRequest* request = CreateDiscoveryRequest( + arena.ptr(), type_url.c_str(), /*version=*/"", nonce, error, + /*node=*/nullptr, /*build_version=*/nullptr); + return SerializeDiscoveryRequest(arena.ptr(), request); +} + +grpc_slice XdsApi::CreateLdsRequest(const std::string& server_name, + const std::string& version, + const std::string& nonce, grpc_error* error, + bool populate_node) { upb::Arena arena; - // Create a request. envoy_api_v2_DiscoveryRequest* request = - envoy_api_v2_DiscoveryRequest_new(arena.ptr()); - // Set version_info. - if (!version.empty()) { - envoy_api_v2_DiscoveryRequest_set_version_info( - request, upb_strview_makez(version.c_str())); - } - // Populate node. - if (build_version != nullptr) { - envoy_api_v2_core_Node* node_msg = - envoy_api_v2_DiscoveryRequest_mutable_node(request, arena.ptr()); - PopulateNode(arena.ptr(), node, build_version, node_msg); - } + CreateDiscoveryRequest(arena.ptr(), kLdsTypeUrl, version, nonce, error, + populate_node ? node_ : nullptr, + populate_node ? build_version_ : nullptr); + // Add resource_name. + envoy_api_v2_DiscoveryRequest_add_resource_names( + request, upb_strview_make(server_name.data(), server_name.size()), + arena.ptr()); + return SerializeDiscoveryRequest(arena.ptr(), request); +} + +grpc_slice XdsApi::CreateRdsRequest(const std::string& route_config_name, + const std::string& version, + const std::string& nonce, grpc_error* error, + bool populate_node) { + upb::Arena arena; + envoy_api_v2_DiscoveryRequest* request = + CreateDiscoveryRequest(arena.ptr(), kRdsTypeUrl, version, nonce, error, + populate_node ? node_ : nullptr, + populate_node ? build_version_ : nullptr); // Add resource_name. envoy_api_v2_DiscoveryRequest_add_resource_names( request, upb_strview_make(route_config_name.data(), route_config_name.size()), arena.ptr()); - // Set type_url. - envoy_api_v2_DiscoveryRequest_set_type_url(request, - upb_strview_makez(kRdsTypeUrl)); - // Set nonce. - if (!nonce.empty()) { - envoy_api_v2_DiscoveryRequest_set_response_nonce( - request, upb_strview_makez(nonce.c_str())); - } - // Set error_detail if it's a NACK. - if (error != GRPC_ERROR_NONE) { - grpc_slice error_description_slice; - GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, - &error_description_slice)); - upb_strview error_description_strview = - upb_strview_make(reinterpret_cast( - GPR_SLICE_START_PTR(error_description_slice)), - GPR_SLICE_LENGTH(error_description_slice)); - google_rpc_Status* error_detail = - envoy_api_v2_DiscoveryRequest_mutable_error_detail(request, - arena.ptr()); - google_rpc_Status_set_message(error_detail, error_description_strview); - GRPC_ERROR_UNREF(error); - } - // Encode the request. - size_t output_length; - char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(), - &output_length); - return grpc_slice_from_copied_buffer(output, output_length); + return SerializeDiscoveryRequest(arena.ptr(), request); } -grpc_slice XdsCdsRequestCreateAndEncode( - const std::set& cluster_names, const XdsBootstrap::Node* node, - const char* build_version, const std::string& version, - const std::string& nonce, grpc_error* error) { +grpc_slice XdsApi::CreateCdsRequest(const std::set& cluster_names, + const std::string& version, + const std::string& nonce, grpc_error* error, + bool populate_node) { upb::Arena arena; - // Create a request. envoy_api_v2_DiscoveryRequest* request = - envoy_api_v2_DiscoveryRequest_new(arena.ptr()); - // Set version_info. - if (!version.empty()) { - envoy_api_v2_DiscoveryRequest_set_version_info( - request, upb_strview_makez(version.c_str())); - } - // Populate node. - if (build_version != nullptr) { - envoy_api_v2_core_Node* node_msg = - envoy_api_v2_DiscoveryRequest_mutable_node(request, arena.ptr()); - PopulateNode(arena.ptr(), node, build_version, node_msg); - } + CreateDiscoveryRequest(arena.ptr(), kCdsTypeUrl, version, nonce, error, + populate_node ? node_ : nullptr, + populate_node ? build_version_ : nullptr); // Add resource_names. for (const auto& cluster_name : cluster_names) { envoy_api_v2_DiscoveryRequest_add_resource_names( request, upb_strview_make(cluster_name.data(), cluster_name.size()), arena.ptr()); } - // Set type_url. - envoy_api_v2_DiscoveryRequest_set_type_url(request, - upb_strview_makez(kCdsTypeUrl)); - // Set nonce. - if (!nonce.empty()) { - envoy_api_v2_DiscoveryRequest_set_response_nonce( - request, upb_strview_makez(nonce.c_str())); - } - // Set error_detail if it's a NACK. - if (error != GRPC_ERROR_NONE) { - grpc_slice error_description_slice; - GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, - &error_description_slice)); - upb_strview error_description_strview = - upb_strview_make(reinterpret_cast( - GPR_SLICE_START_PTR(error_description_slice)), - GPR_SLICE_LENGTH(error_description_slice)); - google_rpc_Status* error_detail = - envoy_api_v2_DiscoveryRequest_mutable_error_detail(request, - arena.ptr()); - google_rpc_Status_set_message(error_detail, error_description_strview); - GRPC_ERROR_UNREF(error); - } - // Encode the request. - size_t output_length; - char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(), - &output_length); - return grpc_slice_from_copied_buffer(output, output_length); + return SerializeDiscoveryRequest(arena.ptr(), request); } -grpc_slice XdsEdsRequestCreateAndEncode( - const std::set& eds_service_names, - const XdsBootstrap::Node* node, const char* build_version, - const std::string& version, const std::string& nonce, grpc_error* error) { +grpc_slice XdsApi::CreateEdsRequest( + const std::set& eds_service_names, const std::string& version, + const std::string& nonce, grpc_error* error, bool populate_node) { upb::Arena arena; - // Create a request. envoy_api_v2_DiscoveryRequest* request = - envoy_api_v2_DiscoveryRequest_new(arena.ptr()); - // Set version_info. - if (!version.empty()) { - envoy_api_v2_DiscoveryRequest_set_version_info( - request, upb_strview_makez(version.c_str())); - } - // Populate node. - if (build_version != nullptr) { - envoy_api_v2_core_Node* node_msg = - envoy_api_v2_DiscoveryRequest_mutable_node(request, arena.ptr()); - PopulateNode(arena.ptr(), node, build_version, node_msg); - } + CreateDiscoveryRequest(arena.ptr(), kEdsTypeUrl, version, nonce, error, + populate_node ? node_ : nullptr, + populate_node ? build_version_ : nullptr); // Add resource_names. for (const auto& eds_service_name : eds_service_names) { envoy_api_v2_DiscoveryRequest_add_resource_names( @@ -428,34 +351,7 @@ grpc_slice XdsEdsRequestCreateAndEncode( upb_strview_make(eds_service_name.data(), eds_service_name.size()), arena.ptr()); } - // Set type_url. - envoy_api_v2_DiscoveryRequest_set_type_url(request, - upb_strview_makez(kEdsTypeUrl)); - // Set nonce. - if (!nonce.empty()) { - envoy_api_v2_DiscoveryRequest_set_response_nonce( - request, upb_strview_makez(nonce.c_str())); - } - // Set error_detail if it's a NACK. - if (error != GRPC_ERROR_NONE) { - grpc_slice error_description_slice; - GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION, - &error_description_slice)); - upb_strview error_description_strview = - upb_strview_make(reinterpret_cast( - GPR_SLICE_START_PTR(error_description_slice)), - GPR_SLICE_LENGTH(error_description_slice)); - google_rpc_Status* error_detail = - envoy_api_v2_DiscoveryRequest_mutable_error_detail(request, - arena.ptr()); - google_rpc_Status_set_message(error_detail, error_description_strview); - GRPC_ERROR_UNREF(error); - } - // Encode the request. - size_t output_length; - char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(), - &output_length); - return grpc_slice_from_copied_buffer(output, output_length); + return SerializeDiscoveryRequest(arena.ptr(), request); } namespace { @@ -511,7 +407,7 @@ MatchType DomainPatternMatchType(const std::string& domain_pattern) { grpc_error* RouteConfigParse( const envoy_api_v2_RouteConfiguration* route_config, - const std::string& expected_server_name, RdsUpdate* rds_update) { + const std::string& expected_server_name, XdsApi::RdsUpdate* rds_update) { // Strip off port from server name, if any. size_t pos = expected_server_name.find(':'); std::string expected_host_name = expected_server_name.substr(0, pos); @@ -604,11 +500,9 @@ grpc_error* RouteConfigParse( return GRPC_ERROR_NONE; } -} // namespace - grpc_error* LdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, const std::string& expected_server_name, - LdsUpdate* lds_update, upb_arena* arena) { + XdsApi::LdsUpdate* lds_update, upb_arena* arena) { // Get the resources from the response. size_t size; const google_protobuf_Any* const* resources = @@ -620,7 +514,7 @@ grpc_error* LdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, for (size_t i = 0; i < size; ++i) { // Check the type_url of the resource. const upb_strview type_url = google_protobuf_Any_type_url(resources[i]); - if (!upb_strview_eql(type_url, upb_strview_makez(kLdsTypeUrl))) { + if (!upb_strview_eql(type_url, upb_strview_makez(XdsApi::kLdsTypeUrl))) { return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not LDS."); } // Decode the listener. @@ -655,7 +549,7 @@ grpc_error* LdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, const envoy_api_v2_RouteConfiguration* route_config = envoy_config_filter_network_http_connection_manager_v2_HttpConnectionManager_route_config( http_connection_manager); - RdsUpdate rds_update; + XdsApi::RdsUpdate rds_update; grpc_error* error = RouteConfigParse(route_config, expected_server_name, &rds_update); if (error != GRPC_ERROR_NONE) return error; @@ -690,7 +584,7 @@ grpc_error* LdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, grpc_error* RdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, const std::string& expected_server_name, const std::string& expected_route_config_name, - RdsUpdate* rds_update, upb_arena* arena) { + XdsApi::RdsUpdate* rds_update, upb_arena* arena) { // Get the resources from the response. size_t size; const google_protobuf_Any* const* resources = @@ -702,7 +596,7 @@ grpc_error* RdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, for (size_t i = 0; i < size; ++i) { // Check the type_url of the resource. const upb_strview type_url = google_protobuf_Any_type_url(resources[i]); - if (!upb_strview_eql(type_url, upb_strview_makez(kRdsTypeUrl))) { + if (!upb_strview_eql(type_url, upb_strview_makez(XdsApi::kRdsTypeUrl))) { return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not RDS."); } // Decode the route_config. @@ -720,7 +614,7 @@ grpc_error* RdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, upb_strview_makez(expected_route_config_name.c_str()); if (!upb_strview_eql(name, expected_name)) continue; // Parse the route_config. - RdsUpdate local_rds_update; + XdsApi::RdsUpdate local_rds_update; grpc_error* error = RouteConfigParse(route_config, expected_server_name, &local_rds_update); if (error != GRPC_ERROR_NONE) return error; @@ -732,7 +626,8 @@ grpc_error* RdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, } grpc_error* CdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, - CdsUpdateMap* cds_update_map, upb_arena* arena) { + XdsApi::CdsUpdateMap* cds_update_map, + upb_arena* arena) { // Get the resources from the response. size_t size; const google_protobuf_Any* const* resources = @@ -743,10 +638,10 @@ grpc_error* CdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, } // Parse all the resources in the CDS response. for (size_t i = 0; i < size; ++i) { - CdsUpdate cds_update; + XdsApi::CdsUpdate cds_update; // Check the type_url of the resource. const upb_strview type_url = google_protobuf_Any_type_url(resources[i]); - if (!upb_strview_eql(type_url, upb_strview_makez(kCdsTypeUrl))) { + if (!upb_strview_eql(type_url, upb_strview_makez(XdsApi::kCdsTypeUrl))) { return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not CDS."); } // Decode the cluster. @@ -801,8 +696,6 @@ grpc_error* CdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, return GRPC_ERROR_NONE; } -namespace { - grpc_error* ServerAddressParseAndAppend( const envoy_api_v2_endpoint_LbEndpoint* lb_endpoint, ServerAddressList* list) { @@ -840,7 +733,7 @@ grpc_error* ServerAddressParseAndAppend( grpc_error* LocalityParse( const envoy_api_v2_endpoint_LocalityLbEndpoints* locality_lb_endpoints, - XdsPriorityListUpdate::LocalityMap::Locality* output_locality) { + XdsApi::PriorityListUpdate::LocalityMap::Locality* output_locality) { // Parse LB weight. const google_protobuf_UInt32Value* lb_weight = envoy_api_v2_endpoint_LocalityLbEndpoints_load_balancing_weight( @@ -878,7 +771,7 @@ grpc_error* LocalityParse( grpc_error* DropParseAndAppend( const envoy_api_v2_ClusterLoadAssignment_Policy_DropOverload* drop_overload, - XdsDropConfig* drop_config, bool* drop_all) { + XdsApi::DropConfig* drop_config, bool* drop_all) { // Get the category. upb_strview category = envoy_api_v2_ClusterLoadAssignment_Policy_DropOverload_category( @@ -918,7 +811,7 @@ grpc_error* DropParseAndAppend( grpc_error* EdsResponsedParse( const envoy_api_v2_DiscoveryResponse* response, const std::set& expected_eds_service_names, - EdsUpdateMap* eds_update_map, upb_arena* arena) { + XdsApi::EdsUpdateMap* eds_update_map, upb_arena* arena) { // Get the resources from the response. size_t size; const google_protobuf_Any* const* resources = @@ -928,10 +821,10 @@ grpc_error* EdsResponsedParse( "EDS response contains 0 resource."); } for (size_t i = 0; i < size; ++i) { - EdsUpdate eds_update; + XdsApi::EdsUpdate eds_update; // Check the type_url of the resource. upb_strview type_url = google_protobuf_Any_type_url(resources[i]); - if (!upb_strview_eql(type_url, upb_strview_makez(kEdsTypeUrl))) { + if (!upb_strview_eql(type_url, upb_strview_makez(XdsApi::kEdsTypeUrl))) { return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not EDS."); } // Get the cluster_load_assignment. @@ -960,7 +853,7 @@ grpc_error* EdsResponsedParse( envoy_api_v2_ClusterLoadAssignment_endpoints(cluster_load_assignment, &locality_size); for (size_t j = 0; j < locality_size; ++j) { - XdsPriorityListUpdate::LocalityMap::Locality locality; + XdsApi::PriorityListUpdate::LocalityMap::Locality locality; grpc_error* error = LocalityParse(endpoints[j], &locality); if (error != GRPC_ERROR_NONE) return error; // Filter out locality with weight 0. @@ -968,7 +861,7 @@ grpc_error* EdsResponsedParse( eds_update.priority_list_update.Add(locality); } // Get the drop config. - eds_update.drop_config = MakeRefCounted(); + eds_update.drop_config = MakeRefCounted(); const envoy_api_v2_ClusterLoadAssignment_Policy* policy = envoy_api_v2_ClusterLoadAssignment_policy(cluster_load_assignment); if (policy != nullptr) { @@ -998,7 +891,7 @@ grpc_error* EdsResponsedParse( } // namespace -grpc_error* XdsAdsResponseDecodeAndParse( +grpc_error* XdsApi::ParseAdsResponse( const grpc_slice& encoded_response, const std::string& expected_server_name, const std::string& expected_route_config_name, const std::set& expected_eds_service_names, @@ -1047,7 +940,7 @@ grpc_error* XdsAdsResponseDecodeAndParse( namespace { -grpc_slice LrsRequestEncode( +grpc_slice SerializeLrsRequest( const envoy_service_load_stats_v2_LoadStatsRequest* request, upb_arena* arena) { size_t output_length; @@ -1058,9 +951,7 @@ grpc_slice LrsRequestEncode( } // namespace -grpc_slice XdsLrsRequestCreateAndEncode(const std::string& server_name, - const XdsBootstrap::Node* node, - const char* build_version) { +grpc_slice XdsApi::CreateLrsInitialRequest(const std::string& server_name) { upb::Arena arena; // Create a request. envoy_service_load_stats_v2_LoadStatsRequest* request = @@ -1069,7 +960,7 @@ grpc_slice XdsLrsRequestCreateAndEncode(const std::string& server_name, envoy_api_v2_core_Node* node_msg = envoy_service_load_stats_v2_LoadStatsRequest_mutable_node(request, arena.ptr()); - PopulateNode(arena.ptr(), node, build_version, node_msg); + PopulateNode(arena.ptr(), node_, build_version_, node_msg); // Add cluster stats. There is only one because we only use one server name in // one channel. envoy_api_v2_endpoint_ClusterStats* cluster_stats = @@ -1078,7 +969,7 @@ grpc_slice XdsLrsRequestCreateAndEncode(const std::string& server_name, // Set the cluster name. envoy_api_v2_endpoint_ClusterStats_set_cluster_name( cluster_stats, upb_strview_makez(server_name.c_str())); - return LrsRequestEncode(request, arena.ptr()); + return SerializeLrsRequest(request, arena.ptr()); } namespace { @@ -1123,7 +1014,7 @@ void LocalityStatsPopulate( } // namespace -grpc_slice XdsLrsRequestCreateAndEncode( +grpc_slice XdsApi::CreateLrsRequest( std::map, StringLess> client_stats_map) { upb::Arena arena; @@ -1193,12 +1084,12 @@ grpc_slice XdsLrsRequestCreateAndEncode( timespec.tv_nsec); } } - return LrsRequestEncode(request, arena.ptr()); + return SerializeLrsRequest(request, arena.ptr()); } -grpc_error* XdsLrsResponseDecodeAndParse(const grpc_slice& encoded_response, - std::set* cluster_names, - grpc_millis* load_reporting_interval) { +grpc_error* XdsApi::ParseLrsResponse(const grpc_slice& encoded_response, + std::set* cluster_names, + grpc_millis* load_reporting_interval) { upb::Arena arena; // Decode the response. const envoy_service_load_stats_v2_LoadStatsResponse* decoded_response = diff --git a/src/core/ext/filters/client_channel/xds/xds_api.h b/src/core/ext/filters/client_channel/xds/xds_api.h index f98d32b1e38..56d1e1cbc63 100644 --- a/src/core/ext/filters/client_channel/xds/xds_api.h +++ b/src/core/ext/filters/client_channel/xds/xds_api.h @@ -34,215 +34,218 @@ namespace grpc_core { -constexpr char kLdsTypeUrl[] = "type.googleapis.com/envoy.api.v2.Listener"; -constexpr char kRdsTypeUrl[] = - "type.googleapis.com/envoy.api.v2.RouteConfiguration"; -constexpr char kCdsTypeUrl[] = "type.googleapis.com/envoy.api.v2.Cluster"; -constexpr char kEdsTypeUrl[] = - "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment"; - -struct RdsUpdate { - // The name to use in the CDS request. - std::string cluster_name; -}; - -struct LdsUpdate { - // The name to use in the RDS request. - std::string route_config_name; - // The name to use in the CDS request. Present if the LDS response has it - // inlined. - Optional rds_update; -}; +class XdsApi { + public: + static const char* kLdsTypeUrl; + static const char* kRdsTypeUrl; + static const char* kCdsTypeUrl; + static const char* kEdsTypeUrl; + + struct RdsUpdate { + // The name to use in the CDS request. + std::string cluster_name; + }; -using LdsUpdateMap = std::map; + struct LdsUpdate { + // The name to use in the RDS request. + std::string route_config_name; + // The name to use in the CDS request. Present if the LDS response has it + // inlined. + Optional rds_update; + }; -using RdsUpdateMap = std::map; + using LdsUpdateMap = std::map; -struct CdsUpdate { - // The name to use in the EDS request. - // If empty, the cluster name will be used. - std::string eds_service_name; - // The LRS server to use for load reporting. - // If not set, load reporting will be disabled. - // If set to the empty string, will use the same server we obtained the CDS - // data from. - Optional lrs_load_reporting_server_name; -}; + using RdsUpdateMap = std::map; -using CdsUpdateMap = std::map; + struct CdsUpdate { + // The name to use in the EDS request. + // If empty, the cluster name will be used. + std::string eds_service_name; + // The LRS server to use for load reporting. + // If not set, load reporting will be disabled. + // If set to the empty string, will use the same server we obtained the CDS + // data from. + Optional lrs_load_reporting_server_name; + }; -class XdsPriorityListUpdate { - public: - struct LocalityMap { - struct Locality { - bool operator==(const Locality& other) const { - return *name == *other.name && serverlist == other.serverlist && - lb_weight == other.lb_weight && priority == other.priority; - } + using CdsUpdateMap = std::map; - // This comparator only compares the locality names. - struct Less { - bool operator()(const Locality& lhs, const Locality& rhs) const { - return XdsLocalityName::Less()(lhs.name, rhs.name); + class PriorityListUpdate { + public: + struct LocalityMap { + struct Locality { + bool operator==(const Locality& other) const { + return *name == *other.name && serverlist == other.serverlist && + lb_weight == other.lb_weight && priority == other.priority; } + + // This comparator only compares the locality names. + struct Less { + bool operator()(const Locality& lhs, const Locality& rhs) const { + return XdsLocalityName::Less()(lhs.name, rhs.name); + } + }; + + RefCountedPtr name; + ServerAddressList serverlist; + uint32_t lb_weight; + uint32_t priority; }; - RefCountedPtr name; - ServerAddressList serverlist; - uint32_t lb_weight; - uint32_t priority; + bool Contains(const RefCountedPtr& name) const { + return localities.find(name) != localities.end(); + } + + size_t size() const { return localities.size(); } + + std::map, Locality, XdsLocalityName::Less> + localities; }; - bool Contains(const RefCountedPtr& name) const { - return localities.find(name) != localities.end(); + bool operator==(const PriorityListUpdate& other) const; + bool operator!=(const PriorityListUpdate& other) const { + return !(*this == other); } - size_t size() const { return localities.size(); } + void Add(LocalityMap::Locality locality); - std::map, Locality, XdsLocalityName::Less> - localities; - }; + const LocalityMap* Find(uint32_t priority) const; - bool operator==(const XdsPriorityListUpdate& other) const; - bool operator!=(const XdsPriorityListUpdate& other) const { - return !(*this == other); - } + bool Contains(uint32_t priority) const { + return priority < priorities_.size(); + } + bool Contains(const RefCountedPtr& name); - void Add(LocalityMap::Locality locality); + bool empty() const { return priorities_.empty(); } + size_t size() const { return priorities_.size(); } - const LocalityMap* Find(uint32_t priority) const; + // Callers should make sure the priority list is non-empty. + uint32_t LowestPriority() const { + return static_cast(priorities_.size()) - 1; + } - bool Contains(uint32_t priority) const { - return priority < priorities_.size(); - } - bool Contains(const RefCountedPtr& name); + private: + InlinedVector priorities_; + }; - bool empty() const { return priorities_.empty(); } - size_t size() const { return priorities_.size(); } + // There are two phases of accessing this class's content: + // 1. to initialize in the control plane combiner; + // 2. to use in the data plane combiner. + // So no additional synchronization is needed. + class DropConfig : public RefCounted { + public: + struct DropCategory { + bool operator==(const DropCategory& other) const { + return name == other.name && + parts_per_million == other.parts_per_million; + } - // Callers should make sure the priority list is non-empty. - uint32_t LowestPriority() const { - return static_cast(priorities_.size()) - 1; - } + std::string name; + const uint32_t parts_per_million; + }; - private: - InlinedVector priorities_; -}; + using DropCategoryList = InlinedVector; -// There are two phases of accessing this class's content: -// 1. to initialize in the control plane combiner; -// 2. to use in the data plane combiner. -// So no additional synchronization is needed. -class XdsDropConfig : public RefCounted { - public: - struct DropCategory { - bool operator==(const DropCategory& other) const { - return name == other.name && parts_per_million == other.parts_per_million; + void AddCategory(std::string name, uint32_t parts_per_million) { + drop_category_list_.emplace_back( + DropCategory{std::move(name), parts_per_million}); } - std::string name; - const uint32_t parts_per_million; - }; + // The only method invoked from the data plane combiner. + bool ShouldDrop(const std::string** category_name) const; - using DropCategoryList = InlinedVector; + const DropCategoryList& drop_category_list() const { + return drop_category_list_; + } - void AddCategory(std::string name, uint32_t parts_per_million) { - drop_category_list_.emplace_back( - DropCategory{std::move(name), parts_per_million}); - } + bool operator==(const DropConfig& other) const { + return drop_category_list_ == other.drop_category_list_; + } + bool operator!=(const DropConfig& other) const { return !(*this == other); } - // The only method invoked from the data plane combiner. - bool ShouldDrop(const std::string** category_name) const; + private: + DropCategoryList drop_category_list_; + }; - const DropCategoryList& drop_category_list() const { - return drop_category_list_; - } + struct EdsUpdate { + PriorityListUpdate priority_list_update; + RefCountedPtr drop_config; + bool drop_all = false; + }; - bool operator==(const XdsDropConfig& other) const { - return drop_category_list_ == other.drop_category_list_; - } - bool operator!=(const XdsDropConfig& other) const { - return !(*this == other); - } + using EdsUpdateMap = std::map; + + XdsApi(const XdsBootstrap::Node* node, const char* build_version) + : node_(node), build_version_(build_version) {} + + // Creates a request to nack an unsupported resource type. + // Takes ownership of \a error. + grpc_slice CreateUnsupportedTypeNackRequest(const std::string& type_url, + const std::string& nonce, + grpc_error* error); + + // Creates an LDS request querying \a server_name. + // Takes ownership of \a error. + grpc_slice CreateLdsRequest(const std::string& server_name, + const std::string& version, + const std::string& nonce, grpc_error* error, + bool populate_node); + + // Creates an RDS request querying \a route_config_name. + // Takes ownership of \a error. + grpc_slice CreateRdsRequest(const std::string& route_config_name, + const std::string& version, + const std::string& nonce, grpc_error* error, + bool populate_node); + + // Creates a CDS request querying \a cluster_names. + // Takes ownership of \a error. + grpc_slice CreateCdsRequest(const std::set& cluster_names, + const std::string& version, + const std::string& nonce, grpc_error* error, + bool populate_node); + + // Creates an EDS request querying \a eds_service_names. + // Takes ownership of \a error. + grpc_slice CreateEdsRequest(const std::set& eds_service_names, + const std::string& version, + const std::string& nonce, grpc_error* error, + bool populate_node); + + // Parses the ADS response and outputs the validated update for either CDS or + // EDS. If the response can't be parsed at the top level, \a type_url will + // point to an empty string; otherwise, it will point to the received data. + grpc_error* ParseAdsResponse( + const grpc_slice& encoded_response, + const std::string& expected_server_name, + const std::string& expected_route_config_name, + const std::set& expected_eds_service_names, + LdsUpdate* lds_update, RdsUpdate* rds_update, + CdsUpdateMap* cds_update_map, EdsUpdateMap* eds_update_map, + std::string* version, std::string* nonce, std::string* type_url); + + // Creates an LRS request querying \a server_name. + grpc_slice CreateLrsInitialRequest(const std::string& server_name); + + // Creates an LRS request sending client-side load reports. If all the + // counters are zero, returns empty slice. + grpc_slice CreateLrsRequest(std::map, StringLess> + client_stats_map); + + // Parses the LRS response and returns \a + // load_reporting_interval for client-side load reporting. If there is any + // error, the output config is invalid. + grpc_error* ParseLrsResponse(const grpc_slice& encoded_response, + std::set* cluster_names, + grpc_millis* load_reporting_interval); private: - DropCategoryList drop_category_list_; + const XdsBootstrap::Node* node_; + const char* build_version_; }; -struct EdsUpdate { - XdsPriorityListUpdate priority_list_update; - RefCountedPtr drop_config; - bool drop_all = false; -}; - -using EdsUpdateMap = std::map; - -// Creates a request to nack an unsupported resource type. -// Takes ownership of \a error. -grpc_slice XdsUnsupportedTypeNackRequestCreateAndEncode( - const std::string& type_url, const std::string& nonce, grpc_error* error); - -// Creates an LDS request querying \a server_name. -// Takes ownership of \a error. -grpc_slice XdsLdsRequestCreateAndEncode(const std::string& server_name, - const XdsBootstrap::Node* node, - const char* build_version, - const std::string& version, - const std::string& nonce, - grpc_error* error); - -// Creates an RDS request querying \a route_config_name. -// Takes ownership of \a error. -grpc_slice XdsRdsRequestCreateAndEncode(const std::string& route_config_name, - const XdsBootstrap::Node* node, - const char* build_version, - const std::string& version, - const std::string& nonce, - grpc_error* error); - -// Creates a CDS request querying \a cluster_names. -// Takes ownership of \a error. -grpc_slice XdsCdsRequestCreateAndEncode( - const std::set& cluster_names, const XdsBootstrap::Node* node, - const char* build_version, const std::string& version, - const std::string& nonce, grpc_error* error); - -// Creates an EDS request querying \a eds_service_names. -// Takes ownership of \a error. -grpc_slice XdsEdsRequestCreateAndEncode( - const std::set& eds_service_names, - const XdsBootstrap::Node* node, const char* build_version, - const std::string& version, const std::string& nonce, grpc_error* error); - -// Parses the ADS response and outputs the validated update for either CDS or -// EDS. If the response can't be parsed at the top level, \a type_url will point -// to an empty string; otherwise, it will point to the received data. -grpc_error* XdsAdsResponseDecodeAndParse( - const grpc_slice& encoded_response, const std::string& expected_server_name, - const std::string& expected_route_config_name, - const std::set& expected_eds_service_names, - LdsUpdate* lds_update, RdsUpdate* rds_update, CdsUpdateMap* cds_update_map, - EdsUpdateMap* eds_update_map, std::string* version, std::string* nonce, - std::string* type_url); - -// Creates an LRS request querying \a server_name. -grpc_slice XdsLrsRequestCreateAndEncode(const std::string& server_name, - const XdsBootstrap::Node* node, - const char* build_version); - -// Creates an LRS request sending client-side load reports. If all the counters -// are zero, returns empty slice. -grpc_slice XdsLrsRequestCreateAndEncode( - std::map, StringLess> - client_stats_map); - -// Parses the LRS response and returns \a -// load_reporting_interval for client-side load reporting. If there is any -// error, the output config is invalid. -grpc_error* XdsLrsResponseDecodeAndParse(const grpc_slice& encoded_response, - std::set* cluster_names, - grpc_millis* load_reporting_interval); - } // namespace grpc_core #endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_XDS_XDS_API_H */ diff --git a/src/core/ext/filters/client_channel/xds/xds_client.cc b/src/core/ext/filters/client_channel/xds/xds_client.cc index ee965c01ee5..1d6b86ad65e 100644 --- a/src/core/ext/filters/client_channel/xds/xds_client.cc +++ b/src/core/ext/filters/client_channel/xds/xds_client.cc @@ -187,17 +187,18 @@ class XdsClient::ChannelState::AdsCallState gpr_log(GPR_INFO, "[xds_client %p] %s", self->ads_calld_->xds_client(), grpc_error_string(error)); } - if (self->type_url_ == kLdsTypeUrl || self->type_url_ == kRdsTypeUrl) { + if (self->type_url_ == XdsApi::kLdsTypeUrl || + self->type_url_ == XdsApi::kRdsTypeUrl) { self->ads_calld_->xds_client()->service_config_watcher_->OnError( error); - } else if (self->type_url_ == kCdsTypeUrl) { + } else if (self->type_url_ == XdsApi::kCdsTypeUrl) { ClusterState& state = self->ads_calld_->xds_client()->cluster_map_[self->name_]; for (const auto& p : state.watchers) { p.first->OnError(GRPC_ERROR_REF(error)); } GRPC_ERROR_UNREF(error); - } else if (self->type_url_ == kEdsTypeUrl) { + } else if (self->type_url_ == XdsApi::kEdsTypeUrl) { EndpointState& state = self->ads_calld_->xds_client()->endpoint_map_[self->name_]; for (const auto& p : state.watchers) { @@ -237,10 +238,10 @@ class XdsClient::ChannelState::AdsCallState void SendMessageLocked(const std::string& type_url); - void AcceptLdsUpdate(LdsUpdate lds_update); - void AcceptRdsUpdate(RdsUpdate rds_update); - void AcceptCdsUpdate(CdsUpdateMap cds_update_map); - void AcceptEdsUpdate(EdsUpdateMap eds_update_map); + void AcceptLdsUpdate(XdsApi::LdsUpdate lds_update); + void AcceptRdsUpdate(XdsApi::RdsUpdate rds_update); + void AcceptCdsUpdate(XdsApi::CdsUpdateMap cds_update_map); + void AcceptEdsUpdate(XdsApi::EdsUpdateMap eds_update_map); static void OnRequestSent(void* arg, grpc_error* error); static void OnRequestSentLocked(void* arg, grpc_error* error); @@ -710,13 +711,13 @@ XdsClient::ChannelState::AdsCallState::AdsCallState( GRPC_CLOSURE_INIT(&on_request_sent_, OnRequestSent, this, grpc_schedule_on_exec_ctx); if (xds_client()->service_config_watcher_ != nullptr) { - Subscribe(kLdsTypeUrl, xds_client()->server_name_); + Subscribe(XdsApi::kLdsTypeUrl, xds_client()->server_name_); } for (const auto& p : xds_client()->cluster_map_) { - Subscribe(kCdsTypeUrl, std::string(p.first)); + Subscribe(XdsApi::kCdsTypeUrl, std::string(p.first)); } for (const auto& p : xds_client()->endpoint_map_) { - Subscribe(kEdsTypeUrl, std::string(p.first)); + Subscribe(XdsApi::kEdsTypeUrl, std::string(p.first)); } // Op: recv initial metadata. op = ops; @@ -789,35 +790,31 @@ void XdsClient::ChannelState::AdsCallState::SendMessageLocked( auto& state = state_map_[type_url]; grpc_error* error = state.error; state.error = GRPC_ERROR_NONE; - const XdsBootstrap::Node* node = - sent_initial_message_ ? nullptr : xds_client()->bootstrap_->node(); - const char* build_version = - sent_initial_message_ ? nullptr : xds_client()->build_version_.get(); - sent_initial_message_ = true; grpc_slice request_payload_slice; - if (type_url == kLdsTypeUrl) { - request_payload_slice = XdsLdsRequestCreateAndEncode( - xds_client()->server_name_, node, build_version, state.version, - state.nonce, error); + if (type_url == XdsApi::kLdsTypeUrl) { + request_payload_slice = xds_client()->api_.CreateLdsRequest( + xds_client()->server_name_, state.version, state.nonce, error, + !sent_initial_message_); state.subscribed_resources[xds_client()->server_name_]->Start(Ref()); - } else if (type_url == kRdsTypeUrl) { - request_payload_slice = XdsRdsRequestCreateAndEncode( - xds_client()->route_config_name_, node, build_version, state.version, - state.nonce, error); + } else if (type_url == XdsApi::kRdsTypeUrl) { + request_payload_slice = xds_client()->api_.CreateRdsRequest( + xds_client()->route_config_name_, state.version, state.nonce, error, + !sent_initial_message_); state.subscribed_resources[xds_client()->route_config_name_]->Start(Ref()); - } else if (type_url == kCdsTypeUrl) { - request_payload_slice = XdsCdsRequestCreateAndEncode( - ClusterNamesForRequest(), node, build_version, state.version, - state.nonce, error); - } else if (type_url == kEdsTypeUrl) { - request_payload_slice = XdsEdsRequestCreateAndEncode( - EdsServiceNamesForRequest(), node, build_version, state.version, - state.nonce, error); + } else if (type_url == XdsApi::kCdsTypeUrl) { + request_payload_slice = xds_client()->api_.CreateCdsRequest( + ClusterNamesForRequest(), state.version, state.nonce, error, + !sent_initial_message_); + } else if (type_url == XdsApi::kEdsTypeUrl) { + request_payload_slice = xds_client()->api_.CreateEdsRequest( + EdsServiceNamesForRequest(), state.version, state.nonce, error, + !sent_initial_message_); } else { - request_payload_slice = XdsUnsupportedTypeNackRequestCreateAndEncode( + request_payload_slice = xds_client()->api_.CreateUnsupportedTypeNackRequest( type_url, state.nonce, state.error); state_map_.erase(type_url); } + sent_initial_message_ = true; // Create message payload. send_message_payload_ = grpc_raw_byte_buffer_create(&request_payload_slice, 1); @@ -863,7 +860,7 @@ bool XdsClient::ChannelState::AdsCallState::HasSubscribedResources() const { } void XdsClient::ChannelState::AdsCallState::AcceptLdsUpdate( - LdsUpdate lds_update) { + XdsApi::LdsUpdate lds_update) { const std::string& cluster_name = lds_update.rds_update.has_value() ? lds_update.rds_update.value().cluster_name @@ -876,7 +873,7 @@ void XdsClient::ChannelState::AdsCallState::AcceptLdsUpdate( xds_client(), lds_update.route_config_name.c_str(), cluster_name.c_str()); } - auto& lds_state = state_map_[kLdsTypeUrl]; + auto& lds_state = state_map_[XdsApi::kLdsTypeUrl]; auto& state = lds_state.subscribed_resources[xds_client()->server_name_]; if (state != nullptr) state->Finish(); // Ignore identical update. @@ -906,19 +903,19 @@ void XdsClient::ChannelState::AdsCallState::AcceptLdsUpdate( } } else { // Send RDS request for dynamic resolution. - Subscribe(kRdsTypeUrl, xds_client()->route_config_name_); + Subscribe(XdsApi::kRdsTypeUrl, xds_client()->route_config_name_); } } void XdsClient::ChannelState::AdsCallState::AcceptRdsUpdate( - RdsUpdate rds_update) { + XdsApi::RdsUpdate rds_update) { if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { gpr_log(GPR_INFO, "[xds_client %p] RDS update received: " "cluster_name=%s", xds_client(), rds_update.cluster_name.c_str()); } - auto& rds_state = state_map_[kRdsTypeUrl]; + auto& rds_state = state_map_[XdsApi::kRdsTypeUrl]; auto& state = rds_state.subscribed_resources[xds_client()->route_config_name_]; if (state != nullptr) state->Finish(); @@ -945,11 +942,11 @@ void XdsClient::ChannelState::AdsCallState::AcceptRdsUpdate( } void XdsClient::ChannelState::AdsCallState::AcceptCdsUpdate( - CdsUpdateMap cds_update_map) { - auto& cds_state = state_map_[kCdsTypeUrl]; + XdsApi::CdsUpdateMap cds_update_map) { + auto& cds_state = state_map_[XdsApi::kCdsTypeUrl]; for (auto& p : cds_update_map) { const char* cluster_name = p.first.c_str(); - CdsUpdate& cds_update = p.second; + XdsApi::CdsUpdate& cds_update = p.second; auto& state = cds_state.subscribed_resources[cluster_name]; if (state != nullptr) state->Finish(); if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { @@ -987,11 +984,11 @@ void XdsClient::ChannelState::AdsCallState::AcceptCdsUpdate( } void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdate( - EdsUpdateMap eds_update_map) { - auto& eds_state = state_map_[kEdsTypeUrl]; + XdsApi::EdsUpdateMap eds_update_map) { + auto& eds_state = state_map_[XdsApi::kEdsTypeUrl]; for (auto& p : eds_update_map) { const char* eds_service_name = p.first.c_str(); - EdsUpdate& eds_update = p.second; + XdsApi::EdsUpdate& eds_update = p.second; auto& state = eds_state.subscribed_resources[eds_service_name]; if (state != nullptr) state->Finish(); if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { @@ -1015,9 +1012,9 @@ void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdate( const auto& locality = p.second; gpr_log(GPR_INFO, "[xds_client %p] Priority %" PRIuPTR ", locality %" PRIuPTR - " %s contains %" PRIuPTR " server addresses", + " %s has weight %d, contains %" PRIuPTR " server addresses", xds_client(), priority, locality_count, - locality.name->AsHumanReadableString(), + locality.name->AsHumanReadableString(), locality.lb_weight, locality.serverlist.size()); for (size_t i = 0; i < locality.serverlist.size(); ++i) { char* ipport; @@ -1035,7 +1032,7 @@ void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdate( } for (size_t i = 0; i < eds_update.drop_config->drop_category_list().size(); ++i) { - const XdsDropConfig::DropCategory& drop_category = + const XdsApi::DropConfig::DropCategory& drop_category = eds_update.drop_config->drop_category_list()[i]; gpr_log(GPR_INFO, "[xds_client %p] Drop category %s has drop rate %d per million", @@ -1046,7 +1043,7 @@ void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdate( EndpointState& endpoint_state = xds_client()->endpoint_map_[eds_service_name]; // Ignore identical update. - const EdsUpdate& prev_update = endpoint_state.update; + const XdsApi::EdsUpdate& prev_update = endpoint_state.update; const bool priority_list_changed = prev_update.priority_list_update != eds_update.priority_list_update; const bool drop_config_changed = @@ -1138,15 +1135,15 @@ void XdsClient::ChannelState::AdsCallState::OnResponseReceivedLocked( // mode. We will also need to cancel the timer when we receive a serverlist // from the balancer. // Parse the response. - LdsUpdate lds_update; - RdsUpdate rds_update; - CdsUpdateMap cds_update_map; - EdsUpdateMap eds_update_map; + XdsApi::LdsUpdate lds_update; + XdsApi::RdsUpdate rds_update; + XdsApi::CdsUpdateMap cds_update_map; + XdsApi::EdsUpdateMap eds_update_map; std::string version; std::string nonce; std::string type_url; - // Note that XdsAdsResponseDecodeAndParse() also validate the response. - grpc_error* parse_error = XdsAdsResponseDecodeAndParse( + // Note that ParseAdsResponse() also validates the response. + grpc_error* parse_error = xds_client->api_.ParseAdsResponse( response_slice, xds_client->server_name_, xds_client->route_config_name_, ads_calld->EdsServiceNamesForRequest(), &lds_update, &rds_update, &cds_update_map, &eds_update_map, &version, &nonce, &type_url); @@ -1173,13 +1170,13 @@ void XdsClient::ChannelState::AdsCallState::OnResponseReceivedLocked( } else { ads_calld->seen_response_ = true; // Accept the ADS response according to the type_url. - if (type_url == kLdsTypeUrl) { + if (type_url == XdsApi::kLdsTypeUrl) { ads_calld->AcceptLdsUpdate(std::move(lds_update)); - } else if (type_url == kRdsTypeUrl) { + } else if (type_url == XdsApi::kRdsTypeUrl) { ads_calld->AcceptRdsUpdate(std::move(rds_update)); - } else if (type_url == kCdsTypeUrl) { + } else if (type_url == XdsApi::kCdsTypeUrl) { ads_calld->AcceptCdsUpdate(std::move(cds_update_map)); - } else if (type_url == kEdsTypeUrl) { + } else if (type_url == XdsApi::kEdsTypeUrl) { ads_calld->AcceptEdsUpdate(std::move(eds_update_map)); } state.version = std::move(version); @@ -1258,7 +1255,7 @@ bool XdsClient::ChannelState::AdsCallState::IsCurrentCallOnChannel() const { std::set XdsClient::ChannelState::AdsCallState::ClusterNamesForRequest() { std::set cluster_names; - for (auto& p : state_map_[kCdsTypeUrl].subscribed_resources) { + for (auto& p : state_map_[XdsApi::kCdsTypeUrl].subscribed_resources) { cluster_names.insert(p.first); OrphanablePtr& state = p.second; state->Start(Ref()); @@ -1269,7 +1266,7 @@ XdsClient::ChannelState::AdsCallState::ClusterNamesForRequest() { std::set XdsClient::ChannelState::AdsCallState::EdsServiceNamesForRequest() { std::set eds_names; - for (auto& p : state_map_[kEdsTypeUrl].subscribed_resources) { + for (auto& p : state_map_[XdsApi::kEdsTypeUrl].subscribed_resources) { eds_names.insert(p.first); OrphanablePtr& state = p.second; state->Start(Ref()); @@ -1320,7 +1317,7 @@ void XdsClient::ChannelState::LrsCallState::Reporter::OnNextReportTimerLocked( void XdsClient::ChannelState::LrsCallState::Reporter::SendReportLocked() { // Create a request that contains the load report. grpc_slice request_payload_slice = - XdsLrsRequestCreateAndEncode(xds_client()->ClientStatsMap()); + xds_client()->api_.CreateLrsRequest(xds_client()->ClientStatsMap()); // Skip client load report if the counters were all zero in the last // report and they are still zero in this one. const bool old_val = last_report_counters_were_zero_; @@ -1396,9 +1393,8 @@ XdsClient::ChannelState::LrsCallState::LrsCallState( nullptr, GRPC_MILLIS_INF_FUTURE, nullptr); GPR_ASSERT(call_ != nullptr); // Init the request payload. - grpc_slice request_payload_slice = XdsLrsRequestCreateAndEncode( - xds_client()->server_name_, xds_client()->bootstrap_->node(), - xds_client()->build_version_.get()); + grpc_slice request_payload_slice = + xds_client()->api_.CreateLrsInitialRequest(xds_client()->server_name_); send_message_payload_ = grpc_raw_byte_buffer_create(&request_payload_slice, 1); grpc_slice_unref_internal(request_payload_slice); @@ -1577,7 +1573,7 @@ void XdsClient::ChannelState::LrsCallState::OnResponseReceivedLocked( // Parse the response. std::set new_cluster_names; grpc_millis new_load_reporting_interval; - grpc_error* parse_error = XdsLrsResponseDecodeAndParse( + grpc_error* parse_error = xds_client->api_.ParseLrsResponse( response_slice, &new_cluster_names, &new_load_reporting_interval); if (parse_error != GRPC_ERROR_NONE) { gpr_log(GPR_ERROR, @@ -1722,6 +1718,8 @@ XdsClient::XdsClient(Combiner* combiner, grpc_pollset_set* interested_parties, combiner_(GRPC_COMBINER_REF(combiner, "xds_client")), interested_parties_(interested_parties), bootstrap_(XdsBootstrap::ReadFromFile(error)), + api_(bootstrap_ == nullptr ? nullptr : bootstrap_->node(), + build_version_.get()), server_name_(server_name), service_config_watcher_(std::move(watcher)) { if (*error != GRPC_ERROR_NONE) { @@ -1744,7 +1742,7 @@ XdsClient::XdsClient(Combiner* combiner, grpc_pollset_set* interested_parties, chand_ = MakeOrphanable( Ref(DEBUG_LOCATION, "XdsClient+ChannelState"), channel); if (service_config_watcher_ != nullptr) { - chand_->Subscribe(kLdsTypeUrl, std::string(server_name)); + chand_->Subscribe(XdsApi::kLdsTypeUrl, std::string(server_name)); } } @@ -1769,7 +1767,7 @@ void XdsClient::WatchClusterData( if (cluster_state.update.has_value()) { w->OnClusterChanged(cluster_state.update.value()); } - chand_->Subscribe(kCdsTypeUrl, cluster_name_str); + chand_->Subscribe(XdsApi::kCdsTypeUrl, cluster_name_str); } void XdsClient::CancelClusterDataWatch(StringView cluster_name, @@ -1782,7 +1780,7 @@ void XdsClient::CancelClusterDataWatch(StringView cluster_name, cluster_state.watchers.erase(it); if (cluster_state.watchers.empty()) { cluster_map_.erase(cluster_name_str); - chand_->Unsubscribe(kCdsTypeUrl, cluster_name_str); + chand_->Unsubscribe(XdsApi::kCdsTypeUrl, cluster_name_str); } } } @@ -1799,7 +1797,7 @@ void XdsClient::WatchEndpointData( if (!endpoint_state.update.priority_list_update.empty()) { w->OnEndpointChanged(endpoint_state.update); } - chand_->Subscribe(kEdsTypeUrl, eds_service_name_str); + chand_->Subscribe(XdsApi::kEdsTypeUrl, eds_service_name_str); } void XdsClient::CancelEndpointDataWatch(StringView eds_service_name, @@ -1812,7 +1810,7 @@ void XdsClient::CancelEndpointDataWatch(StringView eds_service_name, endpoint_state.watchers.erase(it); if (endpoint_state.watchers.empty()) { endpoint_map_.erase(eds_service_name_str); - chand_->Unsubscribe(kEdsTypeUrl, eds_service_name_str); + chand_->Unsubscribe(XdsApi::kEdsTypeUrl, eds_service_name_str); } } } diff --git a/src/core/ext/filters/client_channel/xds/xds_client.h b/src/core/ext/filters/client_channel/xds/xds_client.h index 4cf4bc8222e..9b0c76c3365 100644 --- a/src/core/ext/filters/client_channel/xds/xds_client.h +++ b/src/core/ext/filters/client_channel/xds/xds_client.h @@ -56,7 +56,7 @@ class XdsClient : public InternallyRefCounted { public: virtual ~ClusterWatcherInterface() = default; - virtual void OnClusterChanged(CdsUpdate cluster_data) = 0; + virtual void OnClusterChanged(XdsApi::CdsUpdate cluster_data) = 0; virtual void OnError(grpc_error* error) = 0; }; @@ -66,7 +66,7 @@ class XdsClient : public InternallyRefCounted { public: virtual ~EndpointWatcherInterface() = default; - virtual void OnEndpointChanged(EdsUpdate update) = 0; + virtual void OnEndpointChanged(XdsApi::EdsUpdate update) = 0; virtual void OnError(grpc_error* error) = 0; }; @@ -175,7 +175,7 @@ class XdsClient : public InternallyRefCounted { std::map> watchers; // The latest data seen from CDS. - Optional update; + Optional update; }; struct EndpointState { @@ -184,7 +184,7 @@ class XdsClient : public InternallyRefCounted { watchers; std::set client_stats; // The latest data seen from EDS. - EdsUpdate update; + XdsApi::EdsUpdate update; }; // Sends an error notification to all watchers. @@ -212,6 +212,7 @@ class XdsClient : public InternallyRefCounted { grpc_pollset_set* interested_parties_; std::unique_ptr bootstrap_; + XdsApi api_; const std::string server_name_; diff --git a/src/objective-c/grpc_objc_internal_library.bzl b/src/objective-c/grpc_objc_internal_library.bzl index 5e355e0cb26..ad212d13715 100644 --- a/src/objective-c/grpc_objc_internal_library.bzl +++ b/src/objective-c/grpc_objc_internal_library.bzl @@ -23,6 +23,7 @@ # each change must be ported from one to the other. # +load("@rules_proto//proto:defs.bzl", "proto_library") load( "//bazel:generate_objc.bzl", "generate_objc", @@ -39,7 +40,7 @@ def proto_library_objc_wrapper( """proto_library for adding dependencies to google/protobuf protos use_well_known_protos - ignored in open source version """ - native.proto_library( + proto_library( name = name, srcs = srcs, deps = deps, diff --git a/src/proto/grpc/channelz/BUILD b/src/proto/grpc/channelz/BUILD index d105ddb261d..6aa9c12385f 100644 --- a/src/proto/grpc/channelz/BUILD +++ b/src/proto/grpc/channelz/BUILD @@ -14,6 +14,7 @@ licenses(["notice"]) # Apache v2 +load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") grpc_package( diff --git a/src/proto/grpc/gcp/BUILD b/src/proto/grpc/gcp/BUILD index 1c22d89e464..96e287f7009 100644 --- a/src/proto/grpc/gcp/BUILD +++ b/src/proto/grpc/gcp/BUILD @@ -14,6 +14,8 @@ licenses(["notice"]) # Apache v2 +load("@rules_proto//proto:defs.bzl", "proto_library") + proto_library( name = "alts_handshaker_proto", srcs = [ diff --git a/src/proto/grpc/health/v1/BUILD b/src/proto/grpc/health/v1/BUILD index 8acc1328aba..9caa531c5a7 100644 --- a/src/proto/grpc/health/v1/BUILD +++ b/src/proto/grpc/health/v1/BUILD @@ -14,6 +14,7 @@ licenses(["notice"]) # Apache v2 +load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") grpc_package( diff --git a/src/proto/grpc/lb/v1/BUILD b/src/proto/grpc/lb/v1/BUILD index 2a6e82a57e7..44ffebe18a2 100644 --- a/src/proto/grpc/lb/v1/BUILD +++ b/src/proto/grpc/lb/v1/BUILD @@ -14,6 +14,7 @@ licenses(["notice"]) # Apache v2 +load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") grpc_package( diff --git a/src/proto/grpc/reflection/v1alpha/BUILD b/src/proto/grpc/reflection/v1alpha/BUILD index 96f2a8ec598..e7824558fa1 100644 --- a/src/proto/grpc/reflection/v1alpha/BUILD +++ b/src/proto/grpc/reflection/v1alpha/BUILD @@ -14,6 +14,7 @@ licenses(["notice"]) # Apache v2 +load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") grpc_package( diff --git a/src/proto/grpc/testing/BUILD b/src/proto/grpc/testing/BUILD index 9d71dfbd953..db187e2356a 100644 --- a/src/proto/grpc/testing/BUILD +++ b/src/proto/grpc/testing/BUILD @@ -14,8 +14,9 @@ licenses(["notice"]) # Apache v2 -load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") load("@grpc_python_dependencies//:requirements.bzl", "requirement") +load("@rules_proto//proto:defs.bzl", "proto_library") +load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") grpc_package( diff --git a/src/proto/grpc/testing/messages.proto b/src/proto/grpc/testing/messages.proto index cda7753c7aa..5993bc6bf14 100644 --- a/src/proto/grpc/testing/messages.proto +++ b/src/proto/grpc/testing/messages.proto @@ -115,6 +115,9 @@ message SimpleResponse { string server_id = 4; // gRPCLB Path. GrpclbRouteType grpclb_route_type = 5; + + // Server hostname. + string hostname = 6; } // Client-streaming request. @@ -190,3 +193,17 @@ message ReconnectInfo { bool passed = 1; repeated int32 backoff_ms = 2; } + +message LoadBalancerStatsRequest { + // Request stats for the next num_rpcs sent by client. + int32 num_rpcs = 1; + // If num_rpcs have not completed within timeout_sec, return partial results. + int32 timeout_sec = 2; +} + +message LoadBalancerStatsResponse { + // The number of completed RPCs for each peer. + map rpcs_by_peer = 1; + // The number of RPCs that failed to record a remote peer. + int32 num_failures = 2; +} diff --git a/src/proto/grpc/testing/proto2/BUILD.bazel b/src/proto/grpc/testing/proto2/BUILD.bazel index 05a4b55f5a6..a28c1b2a672 100644 --- a/src/proto/grpc/testing/proto2/BUILD.bazel +++ b/src/proto/grpc/testing/proto2/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") load("@grpc_python_dependencies//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) diff --git a/src/proto/grpc/testing/test.proto b/src/proto/grpc/testing/test.proto index c049c8fa079..0b198d8c260 100644 --- a/src/proto/grpc/testing/test.proto +++ b/src/proto/grpc/testing/test.proto @@ -77,3 +77,10 @@ service ReconnectService { rpc Start(grpc.testing.ReconnectParams) returns (grpc.testing.Empty); rpc Stop(grpc.testing.Empty) returns (grpc.testing.ReconnectInfo); } + +// A service used to obtain stats for verifying LB behavior. +service LoadBalancerStatsService { + // Gets the backend distribution for RPCs sent by a test client. + rpc GetClientStats(LoadBalancerStatsRequest) + returns (LoadBalancerStatsResponse) {} +} diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 19e0467ec7a..e70122d65e1 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -367,7 +367,8 @@ cdef class _AioCall(GrpcCallWrapper): """Sends one single raw message in bytes.""" await _send_message(self, message, - True, + None, + False, self._loop) async def send_receive_close(self): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi index 69f3fcffbbf..67848cadaf8 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi @@ -66,7 +66,7 @@ cdef class CallbackWrapper: cdef CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler( 'grpc_completion_queue_shutdown', 'Unknown', - RuntimeError) + InternalError) cdef class CallbackCompletionQueue: @@ -153,12 +153,13 @@ async def _receive_message(GrpcCallWrapper grpc_call_wrapper, async def _send_message(GrpcCallWrapper grpc_call_wrapper, bytes message, - bint metadata_sent, + Operation send_initial_metadata_op, + int write_flag, object loop): - cdef SendMessageOperation op = SendMessageOperation(message, _EMPTY_FLAG) + cdef SendMessageOperation op = SendMessageOperation(message, write_flag) cdef tuple ops = (op,) - if not metadata_sent: - ops = prepend_send_initial_metadata_op(ops, None) + if send_initial_metadata_op is not None: + ops = (send_initial_metadata_op,) + ops await execute_batch(grpc_call_wrapper, ops, loop) @@ -184,7 +185,7 @@ async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper, grpc_status_code code, str details, tuple trailing_metadata, - bint metadata_sent, + Operation send_initial_metadata_op, object loop): assert code != StatusCode.ok, 'Expecting non-ok status code.' cdef SendStatusFromServerOperation op = SendStatusFromServerOperation( @@ -194,6 +195,6 @@ async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper, _EMPTY_FLAGS, ) cdef tuple ops = (op,) - if not metadata_sent: - ops = prepend_send_initial_metadata_op(ops, None) + if send_initial_metadata_op is not None: + ops = (send_initial_metadata_op,) + ops await execute_batch(grpc_call_wrapper, ops, loop) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index 38cb8887350..bfa9477b6d1 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -71,8 +71,7 @@ cdef class AioChannel: other design of API if necessary. """ if self._status in (AIO_CHANNEL_STATUS_DESTROYED, AIO_CHANNEL_STATUS_CLOSING): - # TODO(lidiz) switch to UsageError - raise RuntimeError('Channel is closed.') + raise UsageError('Channel is closed.') cdef gpr_timespec c_deadline = _timespec_from_time(deadline) @@ -115,8 +114,7 @@ cdef class AioChannel: The _AioCall object. """ if self.closed(): - # TODO(lidiz) switch to UsageError - raise RuntimeError('Channel is closed.') + raise UsageError('Channel is closed.') cdef CallCredentials cython_call_credentials if python_call_credentials is not None: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi index 76cc2996552..cf9269364ce 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi @@ -67,3 +67,33 @@ class _EOF: EOF = _EOF() + +_COMPRESSION_METADATA_STRING_MAPPING = { + CompressionAlgorithm.none: 'identity', + CompressionAlgorithm.deflate: 'deflate', + CompressionAlgorithm.gzip: 'gzip', +} + +class BaseError(Exception): + """The base class for exceptions generated by gRPC AsyncIO stack.""" + + +class UsageError(BaseError): + """Raised when the usage of API by applications is inappropriate. + + For example, trying to invoke RPC on a closed channel, mixing two styles + of streaming API on the client side. This exception should not be + suppressed. + """ + + +class AbortError(BaseError): + """Raised when calling abort in servicer methods. + + This exception should not be suppressed. Applications may catch it to + perform certain clean-up logic, and then re-raise it. + """ + + +class InternalError(BaseError): + """Raised upon unexpected errors in native code.""" diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi index f6a6f4d7094..d3edb70dafe 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi @@ -31,10 +31,14 @@ cdef class RPCState(GrpcCallWrapper): cdef grpc_status_code status_code cdef str status_details cdef tuple trailing_metadata + cdef object compression_algorithm + cdef bint disable_next_compression cdef bytes method(self) cdef tuple invocation_metadata(self) cdef void raise_for_termination(self) except * + cdef int get_write_flag(self) + cdef Operation create_send_initial_metadata_op_if_not_sent(self) cdef enum AioServerStatus: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index ec72624d261..903c20796f7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -21,13 +21,23 @@ cdef int _EMPTY_FLAG = 0 cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.' cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.' +cdef _augment_metadata(tuple metadata, object compression): + if compression is None: + return metadata + else: + return (( + GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, + _COMPRESSION_METADATA_STRING_MAPPING[compression] + ),) + metadata + + cdef class _HandlerCallDetails: def __cinit__(self, str method, tuple invocation_metadata): self.method = method self.invocation_metadata = invocation_metadata -class _ServerStoppedError(RuntimeError): +class _ServerStoppedError(BaseError): """Raised if the server is stopped.""" @@ -45,6 +55,8 @@ cdef class RPCState: self.status_code = StatusCode.ok self.status_details = '' self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA + self.compression_algorithm = None + self.disable_next_compression = False cdef bytes method(self): return _slice_bytes(self.details.method) @@ -65,10 +77,28 @@ cdef class RPCState: if self.abort_exception is not None: raise self.abort_exception if self.status_sent: - raise RuntimeError(_RPC_FINISHED_DETAILS) + raise UsageError(_RPC_FINISHED_DETAILS) if self.server._status == AIO_SERVER_STATUS_STOPPED: raise _ServerStoppedError(_SERVER_STOPPED_DETAILS) + cdef int get_write_flag(self): + if self.disable_next_compression: + self.disable_next_compression = False + return WriteFlag.no_compress + else: + return _EMPTY_FLAG + + cdef Operation create_send_initial_metadata_op_if_not_sent(self): + cdef SendInitialMetadataOperation op + if self.metadata_sent: + return None + else: + op = SendInitialMetadataOperation( + _augment_metadata(_IMMUTABLE_EMPTY_METADATA, self.compression_algorithm), + _EMPTY_FLAG + ) + return op + def __dealloc__(self): """Cleans the Core objects.""" grpc_call_details_destroy(&self.details) @@ -77,11 +107,6 @@ cdef class RPCState: grpc_call_unref(self.call) -# TODO(lidiz) inherit this from Python level `AioRpcStatus`, we need to improve -# current code structure to make it happen. -class AbortError(Exception): pass - - cdef class _ServicerContext: cdef RPCState _rpc_state cdef object _loop @@ -116,18 +141,23 @@ cdef class _ServicerContext: await _send_message(self._rpc_state, serialize(self._response_serializer, message), - self._rpc_state.metadata_sent, + self._rpc_state.create_send_initial_metadata_op_if_not_sent(), + self._rpc_state.get_write_flag(), self._loop) - if not self._rpc_state.metadata_sent: - self._rpc_state.metadata_sent = True + self._rpc_state.metadata_sent = True async def send_initial_metadata(self, tuple metadata): self._rpc_state.raise_for_termination() if self._rpc_state.metadata_sent: - raise RuntimeError('Send initial metadata failed: already sent') + raise UsageError('Send initial metadata failed: already sent') else: - await _send_initial_metadata(self._rpc_state, metadata, _EMPTY_FLAG, self._loop) + await _send_initial_metadata( + self._rpc_state, + _augment_metadata(metadata, self._rpc_state.compression_algorithm), + _EMPTY_FLAG, + self._loop + ) self._rpc_state.metadata_sent = True async def abort(self, @@ -135,7 +165,7 @@ cdef class _ServicerContext: str details='', tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA): if self._rpc_state.abort_exception is not None: - raise RuntimeError('Abort already called!') + raise UsageError('Abort already called!') else: # Keeps track of the exception object. After abort happen, the RPC # should stop execution. However, if users decided to suppress it, it @@ -156,7 +186,7 @@ cdef class _ServicerContext: actual_code, details, trailing_metadata, - self._rpc_state.metadata_sent, + self._rpc_state.create_send_initial_metadata_op_if_not_sent(), self._loop ) @@ -174,6 +204,15 @@ cdef class _ServicerContext: def set_details(self, str details): self._rpc_state.status_details = details + def set_compression(self, object compression): + if self._rpc_state.metadata_sent: + raise RuntimeError('Compression setting must be specified before sending initial metadata') + else: + self._rpc_state.compression_algorithm = compression + + def disable_next_message_compression(self): + self._rpc_state.disable_next_compression = True + cdef _find_method_handler(str method, tuple metadata, list generic_handlers): cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method, @@ -217,7 +256,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state, # Assembles the batch operations cdef tuple finish_ops finish_ops = ( - SendMessageOperation(response_raw, _EMPTY_FLAGS), + SendMessageOperation(response_raw, rpc_state.get_write_flag()), SendStatusFromServerOperation( rpc_state.trailing_metadata, rpc_state.status_code, @@ -446,7 +485,7 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop): status_code, 'Unexpected %s: %s' % (type(e), e), rpc_state.trailing_metadata, - rpc_state.metadata_sent, + rpc_state.create_send_initial_metadata_op_if_not_sent(), loop ) @@ -492,7 +531,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): StatusCode.unimplemented, 'Method not found!', _IMMUTABLE_EMPTY_METADATA, - rpc_state.metadata_sent, + rpc_state.create_send_initial_metadata_op_if_not_sent(), loop ) return @@ -535,13 +574,13 @@ cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandle cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler( 'grpc_server_shutdown_and_notify', None, - RuntimeError) + InternalError) cdef class AioServer: def __init__(self, loop, thread_pool, generic_handlers, interceptors, - options, maximum_concurrent_rpcs, compression): + options, maximum_concurrent_rpcs): # NOTE(lidiz) Core objects won't be deallocated automatically. # If AioServer.shutdown is not called, those objects will leak. self._server = Server(options) @@ -570,8 +609,6 @@ cdef class AioServer: raise NotImplementedError() if maximum_concurrent_rpcs: raise NotImplementedError() - if compression: - raise NotImplementedError() if thread_pool: raise NotImplementedError() @@ -600,7 +637,7 @@ cdef class AioServer: wrapper.c_functor() ) if error != GRPC_CALL_OK: - raise RuntimeError("Error in grpc_server_request_call: %s" % error) + raise InternalError("Error in grpc_server_request_call: %s" % error) await future return rpc_state @@ -650,7 +687,7 @@ cdef class AioServer: if self._status == AIO_SERVER_STATUS_RUNNING: return elif self._status != AIO_SERVER_STATUS_READY: - raise RuntimeError('Server not in ready state') + raise UsageError('Server not in ready state') self._status = AIO_SERVER_STATUS_RUNNING cdef object server_started = self._loop.create_future() @@ -746,11 +783,7 @@ cdef class AioServer: return True def __dealloc__(self): - """Deallocation of Core objects are ensured by Python grpc.aio.Server. - - If the Cython representation is deallocated without underlying objects - freed, raise an RuntimeError. - """ + """Deallocation of Core objects are ensured by Python layer.""" # TODO(lidiz) if users create server, and then dealloc it immediately. # There is a potential memory leak of created Core server. if self._status != AIO_SERVER_STATUS_STOPPED: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index a3be6bae479..00311b5ea2a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -118,7 +118,7 @@ cdef class Server: def cancel_all_calls(self): if not self.is_shutting_down: - raise RuntimeError("the server must be shutting down to cancel all calls") + raise UsageError("the server must be shutting down to cancel all calls") elif self.is_shutdown: return else: @@ -136,7 +136,7 @@ cdef class Server: pass elif not self.is_shutting_down: if self.backup_shutdown_queue is None: - raise RuntimeError('Server shutdown failed: no completion queue.') + raise InternalError('Server shutdown failed: no completion queue.') else: # the user didn't call shutdown - use our backup queue self._c_shutdown(self.backup_shutdown_queue, None) diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 0839c79010d..d8d284780c1 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -17,12 +17,11 @@ gRPC Async API objects may only be used on the thread on which they were created. AsyncIO doesn't provide thread safety for most of its APIs. """ -import abc -from typing import Any, Optional, Sequence, Text, Tuple -import six +from typing import Any, Optional, Sequence, Tuple import grpc -from grpc._cython.cygrpc import EOF, AbortError, init_grpc_aio +from grpc._cython.cygrpc import (EOF, AbortError, BaseError, UsageError, + init_grpc_aio) from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall from ._call import AioRpcError @@ -34,7 +33,7 @@ from ._typing import ChannelArgumentType def insecure_channel( - target: Text, + target: str, options: Optional[ChannelArgumentType] = None, compression: Optional[grpc.Compression] = None, interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): @@ -57,7 +56,7 @@ def insecure_channel( def secure_channel( - target: Text, + target: str, credentials: grpc.ChannelCredentials, options: Optional[ChannelArgumentType] = None, compression: Optional[grpc.Compression] = None, @@ -88,4 +87,4 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryUnaryMultiCallable', 'ClientCallDetails', 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', 'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel', - 'AbortError') + 'AbortError', 'BaseError', 'UsageError') diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py index 318a1edfcc1..d116982aa79 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py @@ -19,7 +19,7 @@ RPC, e.g. cancellation. """ from abc import ABCMeta, abstractmethod -from typing import AsyncIterable, Awaitable, Generic, Optional, Text, Union +from typing import AsyncIterable, Awaitable, Generic, Optional, Union import grpc @@ -110,7 +110,7 @@ class Call(RpcContext, metaclass=ABCMeta): """ @abstractmethod - async def details(self) -> Text: + async def details(self) -> str: """Accesses the details sent by the server. Returns: diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index e237cc8085e..d06cc18d872 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -16,6 +16,7 @@ import asyncio from functools import partial import logging +import enum from typing import AsyncIterable, Awaitable, Dict, Optional import grpc @@ -143,9 +144,13 @@ class AioRpcError(grpc.RpcError): def _create_rpc_error(initial_metadata: Optional[MetadataType], status: cygrpc.AioRpcStatus) -> AioRpcError: - return AioRpcError(_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], - status.details(), initial_metadata, - status.trailing_metadata()) + return AioRpcError( + _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], + status.details(), + initial_metadata, + status.trailing_metadata(), + status.debug_error_string(), + ) class Call: @@ -234,6 +239,12 @@ class Call: return self._repr() +class _APIStyle(enum.IntEnum): + UNKNOWN = 0 + ASYNC_GENERATOR = 1 + READER_WRITER = 2 + + class _UnaryResponseMixin(Call): _call_response: asyncio.Task @@ -279,10 +290,19 @@ class _UnaryResponseMixin(Call): class _StreamResponseMixin(Call): _message_aiter: AsyncIterable[ResponseType] _preparation: asyncio.Task + _response_style: _APIStyle def _init_stream_response_mixin(self, preparation: asyncio.Task): self._message_aiter = None self._preparation = preparation + self._response_style = _APIStyle.UNKNOWN + + def _update_response_style(self, style: _APIStyle): + if self._response_style is _APIStyle.UNKNOWN: + self._response_style = style + elif self._response_style is not style: + raise cygrpc.UsageError( + 'Please don\'t mix two styles of API for streaming responses') def cancel(self) -> bool: if super().cancel(): @@ -298,6 +318,7 @@ class _StreamResponseMixin(Call): message = await self._read() def __aiter__(self) -> AsyncIterable[ResponseType]: + self._update_response_style(_APIStyle.ASYNC_GENERATOR) if self._message_aiter is None: self._message_aiter = self._fetch_stream_responses() return self._message_aiter @@ -324,6 +345,7 @@ class _StreamResponseMixin(Call): if self.done(): await self._raise_for_status() return cygrpc.EOF + self._update_response_style(_APIStyle.READER_WRITER) response_message = await self._read() @@ -335,20 +357,28 @@ class _StreamResponseMixin(Call): class _StreamRequestMixin(Call): _metadata_sent: asyncio.Event - _done_writing: bool + _done_writing_flag: bool _async_request_poller: Optional[asyncio.Task] + _request_style: _APIStyle def _init_stream_request_mixin( self, request_async_iterator: Optional[AsyncIterable[RequestType]]): self._metadata_sent = asyncio.Event(loop=self._loop) - self._done_writing = False + self._done_writing_flag = False # If user passes in an async iterator, create a consumer Task. if request_async_iterator is not None: self._async_request_poller = self._loop.create_task( self._consume_request_iterator(request_async_iterator)) + self._request_style = _APIStyle.ASYNC_GENERATOR else: self._async_request_poller = None + self._request_style = _APIStyle.READER_WRITER + + def _raise_for_different_style(self, style: _APIStyle): + if self._request_style is not style: + raise cygrpc.UsageError( + 'Please don\'t mix two styles of API for streaming requests') def cancel(self) -> bool: if super().cancel(): @@ -365,8 +395,8 @@ class _StreamRequestMixin(Call): self, request_async_iterator: AsyncIterable[RequestType]) -> None: try: async for request in request_async_iterator: - await self.write(request) - await self.done_writing() + await self._write(request) + await self._done_writing() except AioRpcError as rpc_error: # Rpc status should be exposed through other API. Exceptions raised # within this Task won't be retrieved by another coroutine. It's @@ -374,10 +404,10 @@ class _StreamRequestMixin(Call): _LOGGER.debug('Exception while consuming the request_iterator: %s', rpc_error) - async def write(self, request: RequestType) -> None: + async def _write(self, request: RequestType) -> None: if self.done(): raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) - if self._done_writing: + if self._done_writing_flag: raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) if not self._metadata_sent.is_set(): await self._metadata_sent.wait() @@ -394,14 +424,13 @@ class _StreamRequestMixin(Call): self.cancel() await self._raise_for_status() - async def done_writing(self) -> None: - """Implementation of done_writing is idempotent.""" + async def _done_writing(self) -> None: if self.done(): # If the RPC is finished, do nothing. return - if not self._done_writing: + if not self._done_writing_flag: # If the done writing is not sent before, try to send it. - self._done_writing = True + self._done_writing_flag = True try: await self._cython_call.send_receive_close() except asyncio.CancelledError: @@ -409,6 +438,18 @@ class _StreamRequestMixin(Call): self.cancel() await self._raise_for_status() + async def write(self, request: RequestType) -> None: + self._raise_for_different_style(_APIStyle.READER_WRITER) + await self._write(request) + + async def done_writing(self) -> None: + """Signal peer that client is done writing. + + This method is idempotent. + """ + self._raise_for_different_style(_APIStyle.READER_WRITER) + await self._done_writing() + class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): """Object for managing unary-unary RPC calls. diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 5d30e23fa80..7201aabcc73 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,13 +13,15 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet, Text +from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet from weakref import WeakSet import logging import grpc from grpc import _common from grpc._cython import cygrpc +from grpc import _compression +from grpc import _grpcio_metadata from . import _base_call from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, @@ -31,6 +33,20 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, from ._utils import _timeout_to_deadline _IMMUTABLE_EMPTY_TUPLE = tuple() +_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) + + +def _augment_channel_arguments(base_options: ChannelArgumentType, + compression: Optional[grpc.Compression]): + compression_channel_argument = _compression.create_channel_option( + compression) + user_agent_channel_argument = (( + cygrpc.ChannelArgKey.primary_user_agent_string, + _USER_AGENT, + ),) + return tuple(base_options + ) + compression_channel_argument + user_agent_channel_argument + _LOGGER = logging.getLogger(__name__) @@ -110,7 +126,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, + metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -139,10 +155,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): metadata, status code, and details. """ if compression: - raise NotImplementedError("TODO: compression not implemented yet") - - if metadata is None: - metadata = _IMMUTABLE_EMPTY_TUPLE + metadata = _compression.augment_metadata(metadata, compression) if not self._interceptors: call = UnaryUnaryCall(request, _timeout_to_deadline(timeout), @@ -168,7 +181,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable): request: Any, *, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, + metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -192,11 +205,9 @@ class UnaryStreamMultiCallable(_BaseMultiCallable): A Call object instance which is an awaitable object. """ if compression: - raise NotImplementedError("TODO: compression not implemented yet") + metadata = _compression.augment_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) - if metadata is None: - metadata = _IMMUTABLE_EMPTY_TUPLE call = UnaryStreamCall(request, deadline, metadata, credentials, wait_for_ready, self._channel, self._method, @@ -212,7 +223,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable): def __call__(self, request_async_iterator: Optional[AsyncIterable[Any]] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, + metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -241,11 +252,9 @@ class StreamUnaryMultiCallable(_BaseMultiCallable): metadata, status code, and details. """ if compression: - raise NotImplementedError("TODO: compression not implemented yet") + metadata = _compression.augment_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) - if metadata is None: - metadata = _IMMUTABLE_EMPTY_TUPLE call = StreamUnaryCall(request_async_iterator, deadline, metadata, credentials, wait_for_ready, self._channel, @@ -261,7 +270,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable): def __call__(self, request_async_iterator: Optional[AsyncIterable[Any]] = None, timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, + metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[grpc.Compression] = None @@ -290,11 +299,9 @@ class StreamStreamMultiCallable(_BaseMultiCallable): metadata, status code, and details. """ if compression: - raise NotImplementedError("TODO: compression not implemented yet") + metadata = _compression.augment_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) - if metadata is None: - metadata = _IMMUTABLE_EMPTY_TUPLE call = StreamStreamCall(request_async_iterator, deadline, metadata, credentials, wait_for_ready, self._channel, @@ -314,7 +321,7 @@ class Channel: _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] _ongoing_calls: _OngoingCalls - def __init__(self, target: Text, options: Optional[ChannelArgumentType], + def __init__(self, target: str, options: ChannelArgumentType, credentials: Optional[grpc.ChannelCredentials], compression: Optional[grpc.Compression], interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]): @@ -329,10 +336,6 @@ class Channel: interceptors: An optional list of interceptors that would be used for intercepting any RPC executed with that channel. """ - - if compression: - raise NotImplementedError("TODO: compression not implemented yet") - if interceptors is None: self._unary_unary_interceptors = None else: @@ -352,8 +355,10 @@ class Channel: .format(invalid_interceptors)) self._loop = asyncio.get_event_loop() - self._channel = cygrpc.AioChannel(_common.encode(target), options, - credentials, self._loop) + self._channel = cygrpc.AioChannel( + _common.encode(target), + _augment_channel_arguments(options, compression), credentials, + self._loop) self._ongoing_calls = _OngoingCalls() async def __aenter__(self): @@ -456,9 +461,16 @@ class Channel: assert await self._channel.watch_connectivity_state( last_observed_state.value[0], None) + async def channel_ready(self) -> None: + """Creates a coroutine that ends when a Channel is ready.""" + state = self.get_state(try_to_connect=True) + while state != grpc.ChannelConnectivity.READY: + await self.wait_for_state_change(state) + state = self.get_state(try_to_connect=True) + def unary_unary( self, - method: Text, + method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> UnaryUnaryMultiCallable: @@ -484,7 +496,7 @@ class Channel: def unary_stream( self, - method: Text, + method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> UnaryStreamMultiCallable: @@ -495,7 +507,7 @@ class Channel: def stream_unary( self, - method: Text, + method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamUnaryMultiCallable: @@ -506,7 +518,7 @@ class Channel: def stream_stream( self, - method: Text, + method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamStreamMultiCallable: diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 977fad71a4e..aca93fd468c 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -16,7 +16,7 @@ import asyncio import collections import functools from abc import ABCMeta, abstractmethod -from typing import Callable, Optional, Iterator, Sequence, Text, Union +from typing import Callable, Optional, Iterator, Sequence, Union import grpc from grpc._cython import cygrpc @@ -36,7 +36,7 @@ class ClientCallDetails( ('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')), grpc.ClientCallDetails): - method: Text + method: str timeout: Optional[float] metadata: Optional[MetadataType] credentials: Optional[grpc.CallCredentials] diff --git a/src/python/grpcio/grpc/experimental/aio/_server.py b/src/python/grpcio/grpc/experimental/aio/_server.py index ed93aad1c76..13ff381af09 100644 --- a/src/python/grpcio/grpc/experimental/aio/_server.py +++ b/src/python/grpcio/grpc/experimental/aio/_server.py @@ -13,39 +13,52 @@ # limitations under the License. """Server-side implementation of gRPC Asyncio Python.""" -from typing import Text, Optional import asyncio +from concurrent.futures import Executor +from typing import Any, Optional, Sequence + import grpc -from grpc import _common +from grpc import _common, _compression from grpc._cython import cygrpc +from ._typing import ChannelArgumentType + + +def _augment_channel_arguments(base_options: ChannelArgumentType, + compression: Optional[grpc.Compression]): + compression_option = _compression.create_channel_option(compression) + return tuple(base_options) + compression_option + class Server: """Serves RPCs.""" - def __init__(self, thread_pool, generic_handlers, interceptors, options, - maximum_concurrent_rpcs, compression): + def __init__(self, thread_pool: Optional[Executor], + generic_handlers: Optional[Sequence[grpc.GenericRpcHandler]], + interceptors: Optional[Sequence[Any]], + options: ChannelArgumentType, + maximum_concurrent_rpcs: Optional[int], + compression: Optional[grpc.Compression]): self._loop = asyncio.get_event_loop() - self._server = cygrpc.AioServer(self._loop, thread_pool, - generic_handlers, interceptors, options, - maximum_concurrent_rpcs, compression) + self._server = cygrpc.AioServer( + self._loop, thread_pool, generic_handlers, interceptors, + _augment_channel_arguments(options, compression), + maximum_concurrent_rpcs) def add_generic_rpc_handlers( self, - generic_rpc_handlers, - # generic_rpc_handlers: Iterable[grpc.GenericRpcHandlers] - ) -> None: + generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None: """Registers GenericRpcHandlers with this Server. This method is only safe to call before the server is started. Args: - generic_rpc_handlers: An iterable of GenericRpcHandlers that will be + generic_rpc_handlers: A sequence of GenericRpcHandlers that will be used to service RPCs. """ self._server.add_generic_rpc_handlers(generic_rpc_handlers) - def add_insecure_port(self, address: Text) -> int: + def add_insecure_port(self, address: str) -> int: """Opens an insecure port for accepting RPCs. This method may only be called before starting the server. @@ -59,7 +72,7 @@ class Server: """ return self._server.add_insecure_port(_common.encode(address)) - def add_secure_port(self, address: Text, + def add_secure_port(self, address: str, server_credentials: grpc.ServerCredentials) -> int: """Opens a secure port for accepting RPCs. @@ -141,12 +154,12 @@ class Server: self._loop.create_task(self._server.shutdown(None)) -def server(migration_thread_pool=None, - handlers=None, - interceptors=None, - options=None, - maximum_concurrent_rpcs=None, - compression=None): +def server(migration_thread_pool: Optional[Executor] = None, + handlers: Optional[Sequence[grpc.GenericRpcHandler]] = None, + interceptors: Optional[Sequence[Any]] = None, + options: Optional[ChannelArgumentType] = None, + maximum_concurrent_rpcs: Optional[int] = None, + compression: Optional[grpc.Compression] = None): """Creates a Server with which RPCs can be serviced. Args: @@ -166,7 +179,8 @@ def server(migration_thread_pool=None, indicate no limit. compression: An element of grpc.compression, e.g. grpc.compression.Gzip. This compression algorithm will be used for the - lifetime of the server unless overridden. This is an EXPERIMENTAL option. + lifetime of the server unless overridden by set_compression. This is an + EXPERIMENTAL option. Returns: A Server object. diff --git a/src/python/grpcio/grpc/experimental/aio/_typing.py b/src/python/grpcio/grpc/experimental/aio/_typing.py index 15583754a63..ccd2f529936 100644 --- a/src/python/grpcio/grpc/experimental/aio/_typing.py +++ b/src/python/grpcio/grpc/experimental/aio/_typing.py @@ -13,15 +13,15 @@ # limitations under the License. """Common types for gRPC Async API""" -from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar +from typing import Any, AnyStr, Callable, Sequence, Tuple, TypeVar from grpc._cython.cygrpc import EOF RequestType = TypeVar('RequestType') ResponseType = TypeVar('ResponseType') SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] -MetadatumType = Tuple[Text, AnyStr] +MetadatumType = Tuple[str, AnyStr] MetadataType = Sequence[MetadatumType] -ChannelArgumentType = Sequence[Tuple[Text, Any]] +ChannelArgumentType = Sequence[Tuple[str, Any]] EOFType = type(EOF) DoneCallbackType = Callable[[Any], None] diff --git a/src/python/grpcio_tests/tests/stress/BUILD.bazel b/src/python/grpcio_tests/tests/stress/BUILD.bazel index b8af844373a..01c1509c2a1 100644 --- a/src/python/grpcio_tests/tests/stress/BUILD.bazel +++ b/src/python/grpcio_tests/tests/stress/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") proto_library( diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 4e787ec5c6d..d3765c7a531 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -9,9 +9,11 @@ "unit.call_test.TestUnaryStreamCall", "unit.call_test.TestUnaryUnaryCall", "unit.channel_argument_test.TestChannelArgument", + "unit.channel_ready_test.TestChannelReady", "unit.channel_test.TestChannel", "unit.close_channel_test.TestCloseChannel", "unit.close_channel_test.TestOngoingCalls", + "unit.compression_test.TestCompression", "unit.connectivity_test.TestConnectivityState", "unit.done_callback_test.TestDoneCallback", "unit.init_test.TestInsecureChannel", diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py new file mode 100644 index 00000000000..37bdbf8b755 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py @@ -0,0 +1,67 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing the channel_ready function.""" + +import asyncio +import gc +import logging +import time +import unittest + +import grpc +from grpc.experimental import aio + +from tests.unit.framework.common import get_socket, test_constants +from tests_aio.unit import _common +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + + +class TestChannelReady(AioTestBase): + + async def setUp(self): + address, self._port, self._socket = get_socket(listen=False) + self._channel = aio.insecure_channel(f"{address}:{self._port}") + self._socket.close() + + async def tearDown(self): + await self._channel.close() + + async def test_channel_ready_success(self): + # Start `channel_ready` as another Task + channel_ready_task = self.loop.create_task( + self._channel.channel_ready()) + + # Wait for TRANSIENT_FAILURE + await _common.block_until_certain_state( + self._channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE) + + try: + # Start the server + _, server = await start_test_server(port=self._port) + + # The RPC should recover itself + await channel_ready_task + finally: + await server.stop(None) + + async def test_channel_ready_blocked(self): + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(self._channel.channel_ready(), + test_constants.SHORT_TIMEOUT) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/compression_test.py b/src/python/grpcio_tests/tests_aio/unit/compression_test.py new file mode 100644 index 00000000000..9d93885ea23 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/compression_test.py @@ -0,0 +1,196 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior around the compression mechanism.""" + +import asyncio +import logging +import platform +import random +import unittest + +import grpc +from grpc.experimental import aio + +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit import _common + +_GZIP_CHANNEL_ARGUMENT = ('grpc.default_compression_algorithm', 2) +_GZIP_DISABLED_CHANNEL_ARGUMENT = ('grpc.compression_enabled_algorithms_bitset', + 3) +_DEFLATE_DISABLED_CHANNEL_ARGUMENT = ( + 'grpc.compression_enabled_algorithms_bitset', 5) + +_TEST_UNARY_UNARY = '/test/TestUnaryUnary' +_TEST_SET_COMPRESSION = '/test/TestSetCompression' +_TEST_DISABLE_COMPRESSION_UNARY = '/test/TestDisableCompressionUnary' +_TEST_DISABLE_COMPRESSION_STREAM = '/test/TestDisableCompressionStream' + +_REQUEST = b'\x01' * 100 +_RESPONSE = b'\x02' * 100 + + +async def _test_unary_unary(unused_request, unused_context): + return _RESPONSE + + +async def _test_set_compression(unused_request_iterator, context): + assert _REQUEST == await context.read() + context.set_compression(grpc.Compression.Deflate) + await context.write(_RESPONSE) + try: + context.set_compression(grpc.Compression.Deflate) + except RuntimeError: + # NOTE(lidiz) Testing if the servicer context raises exception when + # the set_compression method is called after initial_metadata sent. + # After the initial_metadata sent, the server-side has no control over + # which compression algorithm it should use. + pass + else: + raise ValueError( + 'Expecting exceptions if set_compression is not effective') + + +async def _test_disable_compression_unary(request, context): + assert _REQUEST == request + context.set_compression(grpc.Compression.Deflate) + context.disable_next_message_compression() + return _RESPONSE + + +async def _test_disable_compression_stream(unused_request_iterator, context): + assert _REQUEST == await context.read() + context.set_compression(grpc.Compression.Deflate) + await context.write(_RESPONSE) + context.disable_next_message_compression() + await context.write(_RESPONSE) + await context.write(_RESPONSE) + + +_ROUTING_TABLE = { + _TEST_UNARY_UNARY: + grpc.unary_unary_rpc_method_handler(_test_unary_unary), + _TEST_SET_COMPRESSION: + grpc.stream_stream_rpc_method_handler(_test_set_compression), + _TEST_DISABLE_COMPRESSION_UNARY: + grpc.unary_unary_rpc_method_handler(_test_disable_compression_unary), + _TEST_DISABLE_COMPRESSION_STREAM: + grpc.stream_stream_rpc_method_handler(_test_disable_compression_stream), +} + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + return _ROUTING_TABLE.get(handler_call_details.method) + + +async def _start_test_server(options=None): + server = aio.server(options=options) + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((_GenericHandler(),)) + await server.start() + return f'localhost:{port}', server + + +class TestCompression(AioTestBase): + + async def setUp(self): + server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,) + self._address, self._server = await _start_test_server(server_options) + self._channel = aio.insecure_channel(self._address) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_channel_level_compression_baned_compression(self): + # GZIP is disabled, this call should fail + async with aio.insecure_channel( + self._address, compression=grpc.Compression.Gzip) as channel: + multicallable = channel.unary_unary(_TEST_UNARY_UNARY) + call = multicallable(_REQUEST) + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code()) + + async def test_channel_level_compression_allowed_compression(self): + # Deflate is allowed, this call should succeed + async with aio.insecure_channel( + self._address, compression=grpc.Compression.Deflate) as channel: + multicallable = channel.unary_unary(_TEST_UNARY_UNARY) + call = multicallable(_REQUEST) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_client_call_level_compression_baned_compression(self): + multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY) + + # GZIP is disabled, this call should fail + call = multicallable(_REQUEST, compression=grpc.Compression.Gzip) + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code()) + + async def test_client_call_level_compression_allowed_compression(self): + multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY) + + # Deflate is allowed, this call should succeed + call = multicallable(_REQUEST, compression=grpc.Compression.Deflate) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_server_call_level_compression(self): + multicallable = self._channel.stream_stream(_TEST_SET_COMPRESSION) + call = multicallable() + await call.write(_REQUEST) + await call.done_writing() + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_server_disable_compression_unary(self): + multicallable = self._channel.unary_unary( + _TEST_DISABLE_COMPRESSION_UNARY) + call = multicallable(_REQUEST) + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_server_disable_compression_stream(self): + multicallable = self._channel.stream_stream( + _TEST_DISABLE_COMPRESSION_STREAM) + call = multicallable() + await call.write(_REQUEST) + await call.done_writing() + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_server_default_compression_algorithm(self): + server = aio.server(compression=grpc.Compression.Deflate) + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((_GenericHandler(),)) + await server.start() + + async with aio.insecure_channel(f'localhost:{port}') as channel: + multicallable = channel.unary_unary(_TEST_UNARY_UNARY) + call = multicallable(_REQUEST) + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py index 7acf53b95c7..7f98329070b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py @@ -102,7 +102,7 @@ class TestConnectivityState(AioTestBase): # It can raise exceptions since it is an usage error, but it should not # segfault or abort. - with self.assertRaises(RuntimeError): + with self.assertRaises(aio.UsageError): await channel.wait_for_state_change( grpc.ChannelConnectivity.SHUTDOWN) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 39288d90777..70240fefee1 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -231,14 +231,10 @@ class TestServer(AioTestBase): # Uses reader API self.assertEqual(_RESPONSE, await call.read()) - # Uses async generator API - response_cnt = 0 - async for response in call: - response_cnt += 1 - self.assertEqual(_RESPONSE, response) - - self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt) - self.assertEqual(await call.code(), grpc.StatusCode.OK) + # Uses async generator API, mixed! + with self.assertRaises(aio.UsageError): + async for response in call: + self.assertEqual(_RESPONSE, response) async def test_stream_unary_async_generator(self): stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN) diff --git a/test/cpp/end2end/xds_end2end_test.cc b/test/cpp/end2end/xds_end2end_test.cc index f86d032a878..b3b0867fc1b 100644 --- a/test/cpp/end2end/xds_end2end_test.cc +++ b/test/cpp/end2end/xds_end2end_test.cc @@ -991,7 +991,8 @@ class XdsEnd2endTest : public ::testing::TestWithParam { } std::tuple WaitForAllBackends(size_t start_index = 0, - size_t stop_index = 0) { + size_t stop_index = 0, + bool reset_counters = true) { int num_ok = 0; int num_failure = 0; int num_drops = 0; @@ -999,7 +1000,7 @@ class XdsEnd2endTest : public ::testing::TestWithParam { while (!SeenAllBackends(start_index, stop_index)) { SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops); } - ResetBackendCounters(); + if (reset_counters) ResetBackendCounters(); gpr_log(GPR_INFO, "Performed %d warm up requests against the backends. " "%d succeeded, %d failed, %d dropped.", @@ -2202,6 +2203,41 @@ TEST_P(FailoverTest, UpdatePriority) { EXPECT_EQ(2U, balancers_[0]->ads_service()->response_count()); } +// Moves all localities in the current priority to a higher priority. +TEST_P(FailoverTest, MoveAllLocalitiesInCurrentPriorityToHigherPriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // First update: + // - Priority 0 is locality 0, containing backend 0, which is down. + // - Priority 1 is locality 1, containing backends 1 and 2, which are up. + ShutdownBackend(0); + AdsServiceImpl::ResponseArgs args({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0}, + {"locality1", GetBackendPorts(1, 3), kDefaultLocalityWeight, 1}, + }); + ScheduleResponseForBalancer(0, AdsServiceImpl::BuildResponse(args), 0); + // Second update: + // - Priority 0 contains both localities 0 and 1. + // - Priority 1 is not present. + // - We add backend 3 to locality 1, just so we have a way to know + // when the update has been seen by the client. + args = AdsServiceImpl::ResponseArgs({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0}, + {"locality1", GetBackendPorts(1, 4), kDefaultLocalityWeight, 0}, + }); + ScheduleResponseForBalancer(0, AdsServiceImpl::BuildResponse(args), 1000); + // When we get the first update, all backends in priority 0 are down, + // so we will create priority 1. Backends 1 and 2 should have traffic, + // but backend 3 should not. + WaitForAllBackends(1, 3, false); + EXPECT_EQ(0UL, backends_[3]->backend_service()->request_count()); + // When backend 3 gets traffic, we know the second update has been seen. + WaitForBackend(3); + // The ADS service got a single request, and sent a single response. + EXPECT_EQ(1U, balancers_[0]->ads_service()->request_count()); + EXPECT_EQ(2U, balancers_[0]->ads_service()->response_count()); +} + using DropTest = BasicTest; // Tests that RPCs are dropped according to the drop config. diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl index 34501c32121..d46fcb79c8a 100644 --- a/third_party/py/python_configure.bzl +++ b/third_party/py/python_configure.bzl @@ -14,9 +14,9 @@ _PYTHON3_BIN_PATH = "PYTHON3_BIN_PATH" _PYTHON3_LIB_PATH = "PYTHON3_LIB_PATH" _HEADERS_HELP = ( - "Are Python headers installed? Try installing python-dev or " + - "python3-dev on Debian-based systems. Try python-devel or python3-devel " + - "on Redhat-based systems." + "Are Python headers installed? Try installing python-dev or " + + "python3-dev on Debian-based systems. Try python-devel or python3-devel " + + "on Redhat-based systems." ) def _tpl(repository_ctx, tpl, substitutions = {}, out = None): @@ -246,11 +246,11 @@ def _get_python_include(repository_ctx, python_bin): _execute( repository_ctx, [ - python_bin, - "-c", - "import os;" + - "main_header = os.path.join('{}', 'Python.h');".format(include_path) + - "assert os.path.exists(main_header), main_header + ' does not exist.'" + python_bin, + "-c", + "import os;" + + "main_header = os.path.join('{}', 'Python.h');".format(include_path) + + "assert os.path.exists(main_header), main_header + ' does not exist.'", ], error_msg = "Unable to find Python headers for {}".format(python_bin), error_details = _HEADERS_HELP, diff --git a/third_party/upb/BUILD b/third_party/upb/BUILD index fa2ad904f4b..351571c4d90 100644 --- a/third_party/upb/BUILD +++ b/third_party/upb/BUILD @@ -1,3 +1,4 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") load( "//bazel:build_defs.bzl", "generated_file_staleness_test", @@ -56,13 +57,13 @@ config_setting( cc_library( name = "port", + srcs = [ + "upb/port.c", + ], textual_hdrs = [ "upb/port_def.inc", "upb/port_undef.inc", ], - srcs = [ - "upb/port.c", - ], ) cc_library( @@ -159,8 +160,8 @@ cc_library( cc_library( name = "legacy_msg_reflection", srcs = [ - "upb/msg.h", "upb/legacy_msg_reflection.c", + "upb/msg.h", ], hdrs = ["upb/legacy_msg_reflection.h"], copts = select({ @@ -190,8 +191,8 @@ cc_library( "//conditions:default": COPTS, }), deps = [ - ":reflection", ":port", + ":reflection", ":table", ":upb", ], @@ -220,8 +221,8 @@ cc_library( deps = [ ":descriptor_upbproto", ":handlers", - ":reflection", ":port", + ":reflection", ":table", ":upb", ], diff --git a/third_party/upb/bazel/upb_proto_library.bzl b/third_party/upb/bazel/upb_proto_library.bzl index bea611776c4..45b346d8386 100644 --- a/third_party/upb/bazel/upb_proto_library.bzl +++ b/third_party/upb/bazel/upb_proto_library.bzl @@ -8,6 +8,7 @@ load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") # copybara:strip_for_google3_begin load("@bazel_skylib//lib:versions.bzl", "versions") +load("@rules_proto//proto:defs.bzl", "ProtoInfo") load("@upb_bazel_version//:bazel_version.bzl", "bazel_version") # copybara:strip_end @@ -22,6 +23,7 @@ def _get_real_short_path(file): if short_path.startswith("../"): second_slash = short_path.index("/", 3) short_path = short_path[second_slash + 1:] + # Sometimes it has another few prefixes like: # _virtual_imports/any_proto/google/protobuf/any.proto # We want just google/protobuf/any.proto. diff --git a/third_party/upb/examples/bazel/BUILD b/third_party/upb/examples/bazel/BUILD index b1c60db0469..9fb5c1f4a64 100644 --- a/third_party/upb/examples/bazel/BUILD +++ b/third_party/upb/examples/bazel/BUILD @@ -1,3 +1,4 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") load("@upb//bazel:upb_proto_library.bzl", "upb_proto_library") proto_library( diff --git a/tools/interop_matrix/client_matrix.py b/tools/interop_matrix/client_matrix.py index b5e4229b259..abaac04b160 100644 --- a/tools/interop_matrix/client_matrix.py +++ b/tools/interop_matrix/client_matrix.py @@ -152,6 +152,7 @@ LANG_RELEASE_MATRIX = { ('v1.24.0', ReleaseInfo(runtimes=['go1.11'])), ('v1.25.0', ReleaseInfo(runtimes=['go1.11'])), ('v1.26.0', ReleaseInfo(runtimes=['go1.11'])), + ('v1.27.1', ReleaseInfo(runtimes=['go1.11'])), ]), 'java': OrderedDict([ diff --git a/tools/run_tests/helper_scripts/prep_xds.sh b/tools/run_tests/helper_scripts/prep_xds.sh new file mode 100755 index 00000000000..68128b18068 --- /dev/null +++ b/tools/run_tests/helper_scripts/prep_xds.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright 2020 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -ex + +# change to grpc repo root +cd "$(dirname "$0")/../../.." + +sudo apt-get install -y python3-pip +sudo python3 -m pip install grpcio grpcio-tools google-api-python-client google-auth-httplib2 + +# Prepare generated Python code. +TOOLS_DIR=tools/run_tests +PROTO_SOURCE_DIR=src/proto/grpc/testing +PROTO_DEST_DIR=${TOOLS_DIR}/${PROTO_SOURCE_DIR} +mkdir -p ${PROTO_DEST_DIR} + +python3 -m grpc_tools.protoc \ + --proto_path=. \ + --python_out=${TOOLS_DIR} \ + --grpc_python_out=${TOOLS_DIR} \ + ${PROTO_SOURCE_DIR}/test.proto \ + ${PROTO_SOURCE_DIR}/messages.proto \ + ${PROTO_SOURCE_DIR}/empty.proto diff --git a/tools/run_tests/run_xds_tests.py b/tools/run_tests/run_xds_tests.py new file mode 100755 index 00000000000..1b1435a93b1 --- /dev/null +++ b/tools/run_tests/run_xds_tests.py @@ -0,0 +1,595 @@ +#!/usr/bin/env python +# Copyright 2020 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run xDS integration tests on GCP using Traffic Director.""" + +import argparse +import googleapiclient.discovery +import grpc +import logging +import os +import shlex +import socket +import subprocess +import sys +import tempfile +import time + +from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import test_pb2_grpc + +logger = logging.getLogger(__name__) +console_handler = logging.StreamHandler() +logger.addHandler(console_handler) + +argp = argparse.ArgumentParser(description='Run xDS interop tests on GCP') +argp.add_argument('--project_id', help='GCP project id') +argp.add_argument( + '--gcp_suffix', + default='', + help='Optional suffix for all generated GCP resource names. Useful to ensure ' + 'distinct names across test runs.') +argp.add_argument('--test_case', + default=None, + choices=['all', 'ping_pong', 'round_robin']) +argp.add_argument( + '--client_cmd', + default=None, + help='Command to launch xDS test client. This script will fill in ' + '{service_host}, {service_port},{stats_port} and {qps} parameters using ' + 'str.format(), and generate the GRPC_XDS_BOOTSTRAP file.') +argp.add_argument('--zone', default='us-central1-a') +argp.add_argument('--qps', default=10, help='Client QPS') +argp.add_argument( + '--wait_for_backend_sec', + default=900, + help='Time limit for waiting for created backend services to report healthy ' + 'when launching test suite') +argp.add_argument( + '--keep_gcp_resources', + default=False, + action='store_true', + help= + 'Leave GCP VMs and configuration running after test. Default behavior is ' + 'to delete when tests complete.') +argp.add_argument( + '--tolerate_gcp_errors', + default=False, + action='store_true', + help= + 'Continue with test even when an error occurs during setup. Intended for ' + 'manual testing, where attempts to recreate any GCP resources already ' + 'existing will result in an error') +argp.add_argument('--verbose', + help='verbose log output', + default=False, + action="store_true") +args = argp.parse_args() + +if args.verbose: + logger.setLevel(logging.DEBUG) + +PROJECT_ID = args.project_id +ZONE = args.zone +QPS = args.qps +TEST_CASE = args.test_case +CLIENT_CMD = args.client_cmd +WAIT_FOR_BACKEND_SEC = args.wait_for_backend_sec +TEMPLATE_NAME = 'test-template' + args.gcp_suffix +INSTANCE_GROUP_NAME = 'test-ig' + args.gcp_suffix +HEALTH_CHECK_NAME = 'test-hc' + args.gcp_suffix +FIREWALL_RULE_NAME = 'test-fw-rule' + args.gcp_suffix +BACKEND_SERVICE_NAME = 'test-backend-service' + args.gcp_suffix +URL_MAP_NAME = 'test-map' + args.gcp_suffix +SERVICE_HOST = 'grpc-test' + args.gcp_suffix +TARGET_PROXY_NAME = 'test-target-proxy' + args.gcp_suffix +FORWARDING_RULE_NAME = 'test-forwarding-rule' + args.gcp_suffix +KEEP_GCP_RESOURCES = args.keep_gcp_resources +TOLERATE_GCP_ERRORS = args.tolerate_gcp_errors +SERVICE_PORT = 55551 +STATS_PORT = 55552 +INSTANCE_GROUP_SIZE = 2 +WAIT_FOR_OPERATION_SEC = 60 +NUM_TEST_RPCS = 10 * QPS +WAIT_FOR_STATS_SEC = 30 +BOOTSTRAP_TEMPLATE = """ +{{ + "node": {{ + "id": "{node_id}" + }}, + "xds_servers": [{{ + "server_uri": "trafficdirector.googleapis.com:443", + "channel_creds": [ + {{ + "type": "google_default", + "config": {{}} + }} + ] + }}] +}}""" + + +def get_client_stats(num_rpcs, timeout_sec): + with grpc.insecure_channel('localhost:%d' % STATS_PORT) as channel: + stub = test_pb2_grpc.LoadBalancerStatsServiceStub(channel) + request = messages_pb2.LoadBalancerStatsRequest() + request.num_rpcs = num_rpcs + request.timeout_sec = timeout_sec + try: + response = stub.GetClientStats(request, wait_for_ready=True) + logger.debug('Invoked GetClientStats RPC: %s', response) + return response + except grpc.RpcError as rpc_error: + raise Exception('GetClientStats RPC failed') + + +def wait_until_only_given_backends_receive_load(backends, timeout_sec): + start_time = time.time() + error_msg = None + while time.time() - start_time <= timeout_sec: + error_msg = None + stats = get_client_stats(max(len(backends), 1), timeout_sec) + rpcs_by_peer = stats.rpcs_by_peer + for backend in backends: + if backend not in rpcs_by_peer: + error_msg = 'Backend %s did not receive load' % backend + break + if not error_msg and len(rpcs_by_peer) > len(backends): + error_msg = 'Unexpected backend received load: %s' % rpcs_by_peer + if not error_msg: + return + raise Exception(error_msg) + + +def test_ping_pong(backends, num_rpcs, stats_timeout_sec): + start_time = time.time() + error_msg = None + while time.time() - start_time <= stats_timeout_sec: + error_msg = None + stats = get_client_stats(num_rpcs, stats_timeout_sec) + rpcs_by_peer = stats.rpcs_by_peer + for backend in backends: + if backend not in rpcs_by_peer: + error_msg = 'Backend %s did not receive load' % backend + break + if not error_msg and len(rpcs_by_peer) > len(backends): + error_msg = 'Unexpected backend received load: %s' % rpcs_by_peer + if not error_msg: + return + raise Exception(error_msg) + + +def test_round_robin(backends, num_rpcs, stats_timeout_sec): + threshold = 1 + wait_until_only_given_backends_receive_load(backends, stats_timeout_sec) + stats = get_client_stats(num_rpcs, stats_timeout_sec) + requests_received = [stats.rpcs_by_peer[x] for x in stats.rpcs_by_peer] + total_requests_received = sum( + [stats.rpcs_by_peer[x] for x in stats.rpcs_by_peer]) + if total_requests_received != num_rpcs: + raise Exception('Unexpected RPC failures', stats) + expected_requests = total_requests_received / len(backends) + for backend in backends: + if abs(stats.rpcs_by_peer[backend] - expected_requests) > threshold: + raise Exception( + 'RPC peer distribution differs from expected by more than %d for backend %s (%s)', + threshold, backend, stats) + + +def create_instance_template(compute, name, grpc_port, project): + config = { + 'name': name, + 'properties': { + 'tags': { + 'items': ['grpc-allow-healthcheck'] + }, + 'machineType': 'e2-standard-2', + 'serviceAccounts': [{ + 'email': 'default', + 'scopes': ['https://www.googleapis.com/auth/cloud-platform',] + }], + 'networkInterfaces': [{ + 'accessConfigs': [{ + 'type': 'ONE_TO_ONE_NAT' + }], + 'network': 'global/networks/default' + }], + 'disks': [{ + 'boot': True, + 'initializeParams': { + 'sourceImage': + 'projects/debian-cloud/global/images/family/debian-9' + } + }], + 'metadata': { + 'items': [{ + 'key': + 'startup-script', + 'value': + """#!/bin/bash + +sudo apt update +sudo apt install -y git default-jdk +mkdir java_server +pushd java_server +git clone https://github.com/grpc/grpc-java.git +pushd grpc-java +pushd interop-testing +../gradlew installDist -x test -PskipCodegen=true -PskipAndroid=true + +nohup build/install/grpc-interop-testing/bin/xds-test-server --port=%d 1>/dev/null &""" + % grpc_port + }] + } + } + } + + result = compute.instanceTemplates().insert(project=project, + body=config).execute() + wait_for_global_operation(compute, project, result['name']) + return result['targetLink'] + + +def create_instance_group(compute, name, size, grpc_port, template_url, project, + zone): + config = { + 'name': name, + 'instanceTemplate': template_url, + 'targetSize': size, + 'namedPorts': [{ + 'name': 'grpc', + 'port': grpc_port + }] + } + + result = compute.instanceGroupManagers().insert(project=project, + zone=zone, + body=config).execute() + wait_for_zone_operation(compute, project, zone, result['name']) + result = compute.instanceGroupManagers().get( + project=PROJECT_ID, zone=ZONE, instanceGroupManager=name).execute() + return result['instanceGroup'] + + +def create_health_check(compute, name, project): + config = { + 'name': name, + 'type': 'TCP', + 'tcpHealthCheck': { + 'portName': 'grpc' + } + } + result = compute.healthChecks().insert(project=project, + body=config).execute() + wait_for_global_operation(compute, project, result['name']) + return result['targetLink'] + + +def create_health_check_firewall_rule(compute, name, project): + config = { + 'name': name, + 'direction': 'INGRESS', + 'allowed': [{ + 'IPProtocol': 'tcp' + }], + 'sourceRanges': ['35.191.0.0/16', '130.211.0.0/22'], + 'targetTags': ['grpc-allow-healthcheck'], + } + result = compute.firewalls().insert(project=project, body=config).execute() + wait_for_global_operation(compute, project, result['name']) + + +def create_backend_service(compute, name, instance_group, health_check, + project): + config = { + 'name': name, + 'loadBalancingScheme': 'INTERNAL_SELF_MANAGED', + 'healthChecks': [health_check], + 'portName': 'grpc', + 'protocol': 'HTTP2', + 'backends': [{ + 'group': instance_group, + }] + } + result = compute.backendServices().insert(project=project, + body=config).execute() + wait_for_global_operation(compute, project, result['name']) + return result['targetLink'] + + +def create_url_map(compute, name, backend_service_url, host_name, project): + path_matcher_name = 'path-matcher' + config = { + 'name': name, + 'defaultService': backend_service_url, + 'pathMatchers': [{ + 'name': path_matcher_name, + 'defaultService': backend_service_url, + }], + 'hostRules': [{ + 'hosts': [host_name], + 'pathMatcher': path_matcher_name + }] + } + result = compute.urlMaps().insert(project=project, body=config).execute() + wait_for_global_operation(compute, project, result['name']) + return result['targetLink'] + + +def create_target_http_proxy(compute, name, url_map_url, project): + config = { + 'name': name, + 'url_map': url_map_url, + } + result = compute.targetHttpProxies().insert(project=project, + body=config).execute() + wait_for_global_operation(compute, project, result['name']) + return result['targetLink'] + + +def create_global_forwarding_rule(compute, name, grpc_port, + target_http_proxy_url, project): + config = { + 'name': name, + 'loadBalancingScheme': 'INTERNAL_SELF_MANAGED', + 'portRange': str(grpc_port), + 'IPAddress': '0.0.0.0', + 'target': target_http_proxy_url, + } + result = compute.globalForwardingRules().insert(project=project, + body=config).execute() + wait_for_global_operation(compute, project, result['name']) + + +def delete_global_forwarding_rule(compute, project, forwarding_rule): + try: + result = compute.globalForwardingRules().delete( + project=project, forwardingRule=forwarding_rule).execute() + wait_for_global_operation(compute, project, result['name']) + except googleapiclient.errors.HttpError as http_error: + logger.info('Delete failed: %s', http_error) + + +def delete_target_http_proxy(compute, project, target_http_proxy): + try: + result = compute.targetHttpProxies().delete( + project=project, targetHttpProxy=target_http_proxy).execute() + wait_for_global_operation(compute, project, result['name']) + except googleapiclient.errors.HttpError as http_error: + logger.info('Delete failed: %s', http_error) + + +def delete_url_map(compute, project, url_map): + try: + result = compute.urlMaps().delete(project=project, + urlMap=url_map).execute() + wait_for_global_operation(compute, project, result['name']) + except googleapiclient.errors.HttpError as http_error: + logger.info('Delete failed: %s', http_error) + + +def delete_backend_service(compute, project, backend_service): + try: + result = compute.backendServices().delete( + project=project, backendService=backend_service).execute() + wait_for_global_operation(compute, project, result['name']) + except googleapiclient.errors.HttpError as http_error: + logger.info('Delete failed: %s', http_error) + + +def delete_firewall(compute, project, firewall_rule): + try: + result = compute.firewalls().delete(project=project, + firewall=firewall_rule).execute() + wait_for_global_operation(compute, project, result['name']) + except googleapiclient.errors.HttpError as http_error: + logger.info('Delete failed: %s', http_error) + + +def delete_health_check(compute, project, health_check): + try: + result = compute.healthChecks().delete( + project=project, healthCheck=health_check).execute() + wait_for_global_operation(compute, project, result['name']) + except googleapiclient.errors.HttpError as http_error: + logger.info('Delete failed: %s', http_error) + + +def delete_instance_group(compute, project, zone, instance_group): + try: + result = compute.instanceGroupManagers().delete( + project=project, zone=zone, + instanceGroupManager=instance_group).execute() + timeout_sec = 180 # Deleting an instance group can be slow + wait_for_zone_operation(compute, + project, + ZONE, + result['name'], + timeout_sec=timeout_sec) + except googleapiclient.errors.HttpError as http_error: + logger.info('Delete failed: %s', http_error) + + +def delete_instance_template(compute, project, instance_template): + try: + result = compute.instanceTemplates().delete( + project=project, instanceTemplate=instance_template).execute() + wait_for_global_operation(compute, project, result['name']) + except googleapiclient.errors.HttpError as http_error: + logger.info('Delete failed: %s', http_error) + + +def wait_for_global_operation(compute, + project, + operation, + timeout_sec=WAIT_FOR_OPERATION_SEC): + start_time = time.time() + while time.time() - start_time <= timeout_sec: + result = compute.globalOperations().get(project=project, + operation=operation).execute() + if result['status'] == 'DONE': + if 'error' in result: + raise Exception(result['error']) + return + time.sleep(1) + raise Exception('Operation %s did not complete within %d', operation, + timeout_sec) + + +def wait_for_zone_operation(compute, + project, + zone, + operation, + timeout_sec=WAIT_FOR_OPERATION_SEC): + start_time = time.time() + while time.time() - start_time <= timeout_sec: + result = compute.zoneOperations().get(project=project, + zone=zone, + operation=operation).execute() + if result['status'] == 'DONE': + if 'error' in result: + raise Exception(result['error']) + return + time.sleep(1) + raise Exception('Operation %s did not complete within %d', operation, + timeout_sec) + + +def wait_for_healthy_backends(compute, project_id, backend_service, + instance_group_url, timeout_sec): + start_time = time.time() + config = {'group': instance_group_url} + while time.time() - start_time <= timeout_sec: + result = compute.backendServices().getHealth( + project=project_id, backendService=backend_service, + body=config).execute() + if 'healthStatus' in result: + healthy = True + for instance in result['healthStatus']: + if instance['healthState'] != 'HEALTHY': + healthy = False + break + if healthy: + return + time.sleep(1) + raise Exception('Not all backends became healthy within %d seconds: %s' % + (timeout_sec, result)) + + +def start_xds_client(): + cmd = CLIENT_CMD.format(service_host=SERVICE_HOST, + service_port=SERVICE_PORT, + stats_port=STATS_PORT, + qps=QPS) + bootstrap_path = None + with tempfile.NamedTemporaryFile(delete=False) as bootstrap_file: + bootstrap_file.write( + BOOTSTRAP_TEMPLATE.format( + node_id=socket.gethostname()).encode('utf-8')) + bootstrap_path = bootstrap_file.name + + client_process = subprocess.Popen(shlex.split(cmd), + env=dict( + os.environ, + GRPC_XDS_BOOTSTRAP=bootstrap_path)) + return client_process + + +compute = googleapiclient.discovery.build('compute', 'v1') +client_process = None + +try: + instance_group_url = None + try: + template_url = create_instance_template(compute, TEMPLATE_NAME, + SERVICE_PORT, PROJECT_ID) + instance_group_url = create_instance_group(compute, INSTANCE_GROUP_NAME, + INSTANCE_GROUP_SIZE, + SERVICE_PORT, template_url, + PROJECT_ID, ZONE) + health_check_url = create_health_check(compute, HEALTH_CHECK_NAME, + PROJECT_ID) + create_health_check_firewall_rule(compute, FIREWALL_RULE_NAME, + PROJECT_ID) + backend_service_url = create_backend_service(compute, + BACKEND_SERVICE_NAME, + instance_group_url, + health_check_url, + PROJECT_ID) + url_map_url = create_url_map(compute, URL_MAP_NAME, backend_service_url, + SERVICE_HOST, PROJECT_ID) + target_http_proxy_url = create_target_http_proxy( + compute, TARGET_PROXY_NAME, url_map_url, PROJECT_ID) + create_global_forwarding_rule(compute, FORWARDING_RULE_NAME, + SERVICE_PORT, target_http_proxy_url, + PROJECT_ID) + except googleapiclient.errors.HttpError as http_error: + if TOLERATE_GCP_ERRORS: + logger.warning( + 'Failed to set up backends: %s. Continuing since ' + '--tolerate_gcp_errors=true', http_error) + else: + raise http_error + + if instance_group_url is None: + # Look up the instance group URL, which may be unset if we are running + # with --tolerate_gcp_errors=true. + result = compute.instanceGroups().get( + project=PROJECT_ID, zone=ZONE, + instanceGroup=INSTANCE_GROUP_NAME).execute() + instance_group_url = result['selfLink'] + wait_for_healthy_backends(compute, PROJECT_ID, BACKEND_SERVICE_NAME, + instance_group_url, WAIT_FOR_BACKEND_SEC) + + backends = [] + result = compute.instanceGroups().listInstances( + project=PROJECT_ID, + zone=ZONE, + instanceGroup=INSTANCE_GROUP_NAME, + body={ + 'instanceState': 'ALL' + }).execute() + for item in result['items']: + # listInstances() returns the full URL of the instance, which ends with + # the instance name. compute.instances().get() requires using the + # instance name (not the full URL) to look up instance details, so we + # just extract the name manually. + instance_name = item['instance'].split('/')[-1] + backends.append(instance_name) + + client_process = start_xds_client() + + if TEST_CASE == 'all': + test_ping_pong(backends, NUM_TEST_RPCS, WAIT_FOR_STATS_SEC) + test_round_robin(backends, NUM_TEST_RPCS, WAIT_FOR_STATS_SEC) + elif TEST_CASE == 'ping_pong': + test_ping_pong(backends, NUM_TEST_RPCS, WAIT_FOR_STATS_SEC) + elif TEST_CASE == 'round_robin': + test_round_robin(backends, NUM_TEST_RPCS, WAIT_FOR_STATS_SEC) + else: + logger.error('Unknown test case: %s', TEST_CASE) + sys.exit(1) +finally: + if client_process: + client_process.terminate() + if not KEEP_GCP_RESOURCES: + logger.info('Cleaning up GCP resources. This may take some time.') + delete_global_forwarding_rule(compute, PROJECT_ID, FORWARDING_RULE_NAME) + delete_target_http_proxy(compute, PROJECT_ID, TARGET_PROXY_NAME) + delete_url_map(compute, PROJECT_ID, URL_MAP_NAME) + delete_backend_service(compute, PROJECT_ID, BACKEND_SERVICE_NAME) + delete_firewall(compute, PROJECT_ID, FIREWALL_RULE_NAME) + delete_health_check(compute, PROJECT_ID, HEALTH_CHECK_NAME) + delete_instance_group(compute, PROJECT_ID, ZONE, INSTANCE_GROUP_NAME) + delete_instance_template(compute, PROJECT_ID, TEMPLATE_NAME)