Merge pull request #5 from grpc/master

Sync with grpc repo
pull/22032/head
Zhanghui Mao 5 years ago committed by GitHub
commit 4cbe126830
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 28
      BUILDING.md
  2. 4
      bazel/cc_grpc_library.bzl
  3. 1
      bazel/generate_cc.bzl
  4. 1
      bazel/generate_objc.bzl
  5. 2
      bazel/protobuf.bzl
  6. 1
      bazel/python_rules.bzl
  7. 1
      bazel/test/python_test_repo/BUILD
  8. 10
      doc/environment_variables.md
  9. 5
      examples/BUILD
  10. 6
      examples/cpp/helloworld/README.md
  11. 29
      examples/cpp/helloworld/greeter_client.cc
  12. 1
      examples/python/cancellation/BUILD.bazel
  13. 1
      examples/python/multiprocessing/BUILD
  14. 4
      src/core/ext/filters/client_channel/lb_policy/xds/cds.cc
  15. 18
      src/core/ext/filters/client_channel/lb_policy/xds/xds.cc
  16. 361
      src/core/ext/filters/client_channel/xds/xds_api.cc
  17. 355
      src/core/ext/filters/client_channel/xds/xds_api.h
  18. 136
      src/core/ext/filters/client_channel/xds/xds_client.cc
  19. 9
      src/core/ext/filters/client_channel/xds/xds_client.h
  20. 3
      src/objective-c/grpc_objc_internal_library.bzl
  21. 1
      src/proto/grpc/channelz/BUILD
  22. 2
      src/proto/grpc/gcp/BUILD
  23. 1
      src/proto/grpc/health/v1/BUILD
  24. 1
      src/proto/grpc/lb/v1/BUILD
  25. 1
      src/proto/grpc/reflection/v1alpha/BUILD
  26. 3
      src/proto/grpc/testing/BUILD
  27. 17
      src/proto/grpc/testing/messages.proto
  28. 1
      src/proto/grpc/testing/proto2/BUILD.bazel
  29. 7
      src/proto/grpc/testing/test.proto
  30. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  31. 17
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  32. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  33. 30
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  34. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  35. 89
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  36. 4
      src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
  37. 13
      src/python/grpcio/grpc/experimental/aio/__init__.py
  38. 4
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  39. 67
      src/python/grpcio/grpc/experimental/aio/_call.py
  40. 70
      src/python/grpcio/grpc/experimental/aio/_channel.py
  41. 4
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  42. 54
      src/python/grpcio/grpc/experimental/aio/_server.py
  43. 6
      src/python/grpcio/grpc/experimental/aio/_typing.py
  44. 1
      src/python/grpcio_tests/tests/stress/BUILD.bazel
  45. 2
      src/python/grpcio_tests/tests_aio/tests.json
  46. 67
      src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py
  47. 196
      src/python/grpcio_tests/tests_aio/unit/compression_test.py
  48. 2
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py
  49. 12
      src/python/grpcio_tests/tests_aio/unit/server_test.py
  50. 40
      test/cpp/end2end/xds_end2end_test.cc
  51. 16
      third_party/py/python_configure.bzl
  52. 13
      third_party/upb/BUILD
  53. 2
      third_party/upb/bazel/upb_proto_library.bzl
  54. 1
      third_party/upb/examples/bazel/BUILD
  55. 1
      tools/interop_matrix/client_matrix.py
  56. 36
      tools/run_tests/helper_scripts/prep_xds.sh
  57. 595
      tools/run_tests/run_xds_tests.py

@ -161,10 +161,11 @@ Please note that when using Ninja, you will still need Visual C++ (part of Visua
installed to be able to compile the C/C++ sources. installed to be able to compile the C/C++ sources.
``` ```
> @rem Run from grpc directory after cloning the repo with --recursive or updating submodules. > @rem Run from grpc directory after cloning the repo with --recursive or updating submodules.
> md .build > cd cmake
> cd .build > md build
> cd build
> call "%VS140COMNTOOLS%..\..\VC\vcvarsall.bat" x64 > call "%VS140COMNTOOLS%..\..\VC\vcvarsall.bat" x64
> cmake .. -GNinja -DCMAKE_BUILD_TYPE=Release > cmake ..\.. -GNinja -DCMAKE_BUILD_TYPE=Release
> cmake --build . > cmake --build .
``` ```
@ -183,7 +184,7 @@ ie `gRPC_CARES_PROVIDER`.
### Install after build ### Install after build
Perform the following steps to install gRPC using CMake. Perform the following steps to install gRPC using CMake.
* Set `gRPC_INSTALL` to `ON` * Set `-DgRPC_INSTALL=ON`
* Build the `install` target * Build the `install` target
The install destination is controlled by the The install destination is controlled by the
@ -196,16 +197,21 @@ in "module" mode and install them alongside gRPC in a single step.
If you are using an older version of gRPC, you will need to select "package" If you are using an older version of gRPC, you will need to select "package"
mode (rather than "module" mode) for the dependencies. mode (rather than "module" mode) for the dependencies.
This means you will need to have external copies of these libraries available This means you will need to have external copies of these libraries available
on your system. on your system. This [example](test/distrib/cpp/run_distrib_test_cmake.sh) shows
how to install dependencies with cmake before proceeding to installing gRPC itself.
``` ```
$ cmake .. -DgRPC_CARES_PROVIDER=package \ # NOTE: all of gRPC's dependencies need to be already installed
-DgRPC_PROTOBUF_PROVIDER=package \ $ cmake ../.. -DgRPC_INSTALL=ON \
-DgRPC_SSL_PROVIDER=package \ -DCMAKE_BUILD_TYPE=Release \
-DgRPC_ZLIB_PROVIDER=package -DgRPC_ABSL_PROVIDER=package \
-DgRPC_CARES_PROVIDER=package \
-DgRPC_PROTOBUF_PROVIDER=package \
-DgRPC_SSL_PROVIDER=package \
-DgRPC_ZLIB_PROVIDER=package
$ make $ make
$ make install $ make install
``` ```
[Example](test/distrib/cpp/run_distrib_test_cmake.sh)
### Cross-compiling ### Cross-compiling
@ -222,7 +228,7 @@ that will be used for this build.
This toolchain file is specified to CMake by setting the `CMAKE_TOOLCHAIN_FILE` This toolchain file is specified to CMake by setting the `CMAKE_TOOLCHAIN_FILE`
variable. variable.
``` ```
$ cmake .. -DCMAKE_TOOLCHAIN_FILE=path/to/file $ cmake ../.. -DCMAKE_TOOLCHAIN_FILE=path/to/file
$ make $ make
``` ```

@ -1,5 +1,6 @@
"""Generates and compiles C++ grpc stubs from proto_library rules.""" """Generates and compiles C++ grpc stubs from proto_library rules."""
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel:generate_cc.bzl", "generate_cc") load("//bazel:generate_cc.bzl", "generate_cc")
load("//bazel:protobuf.bzl", "well_known_proto_libs") load("//bazel:protobuf.bzl", "well_known_proto_libs")
@ -63,8 +64,7 @@ def cc_grpc_library(
proto_deps += [dep.split(":")[0] + ":" + "_" + dep.split(":")[1] + "_only" for dep in deps if dep.find(":") != -1] proto_deps += [dep.split(":")[0] + ":" + "_" + dep.split(":")[1] + "_only" for dep in deps if dep.find(":") != -1]
if well_known_protos: if well_known_protos:
proto_deps += well_known_proto_libs() proto_deps += well_known_proto_libs()
proto_library(
native.proto_library(
name = proto_target, name = proto_target,
srcs = srcs, srcs = srcs,
deps = proto_deps, deps = proto_deps,

@ -4,6 +4,7 @@ This is an internal rule used by cc_grpc_library, and shouldn't be used
directly. directly.
""" """
load("@rules_proto//proto:defs.bzl", "ProtoInfo")
load( load(
"//bazel:protobuf.bzl", "//bazel:protobuf.bzl",
"get_include_directory", "get_include_directory",

@ -1,3 +1,4 @@
load("@rules_proto//proto:defs.bzl", "ProtoInfo")
load( load(
"//bazel:protobuf.bzl", "//bazel:protobuf.bzl",
"get_include_directory", "get_include_directory",

@ -1,5 +1,7 @@
"""Utility functions for generating protobuf code.""" """Utility functions for generating protobuf code."""
load("@rules_proto//proto:defs.bzl", "ProtoInfo")
_PROTO_EXTENSION = ".proto" _PROTO_EXTENSION = ".proto"
_VIRTUAL_IMPORTS = "/_virtual_imports/" _VIRTUAL_IMPORTS = "/_virtual_imports/"

@ -1,5 +1,6 @@
"""Generates and compiles Python gRPC stubs from proto_library rules.""" """Generates and compiles Python gRPC stubs from proto_library rules."""
load("@rules_proto//proto:defs.bzl", "ProtoInfo")
load( load(
"//bazel:protobuf.bzl", "//bazel:protobuf.bzl",
"declare_out_files", "declare_out_files",

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
load("@rules_proto//proto:defs.bzl", "proto_library")
load( load(
"@com_github_grpc_grpc//bazel:python_rules.bzl", "@com_github_grpc_grpc//bazel:python_rules.bzl",
"py2and3_test", "py2and3_test",

@ -4,8 +4,14 @@ gRPC environment variables
gRPC C core based implementations (those contained in this repository) expose gRPC C core based implementations (those contained in this repository) expose
some configuration as environment variables that can be set. some configuration as environment variables that can be set.
* http_proxy * grpc_proxy, https_proxy, http_proxy
The URI of the proxy to use for HTTP CONNECT support. The URI of the proxy to use for HTTP CONNECT support. These variables are
checked in order, and the first one that has a value is used.
* no_grpc_proxy, no_proxy
A comma separated list of hostnames to connect to without using a proxy even
if a proxy is set. These variables are checked in order, and the first one
that has a value is used.
* GRPC_ABORT_ON_LEAKS * GRPC_ABORT_ON_LEAKS
A debugging aid to cause a call to abort() when gRPC objects are leaked past A debugging aid to cause a call to abort() when gRPC objects are leaked past

@ -16,10 +16,11 @@ licenses(["notice"]) # 3-clause BSD
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
load("//bazel:grpc_build_system.bzl", "grpc_proto_library") load("@grpc_python_dependencies//:requirements.bzl", "requirement")
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel:cc_grpc_library.bzl", "cc_grpc_library") load("//bazel:cc_grpc_library.bzl", "cc_grpc_library")
load("//bazel:grpc_build_system.bzl", "grpc_proto_library")
load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library")
load("@grpc_python_dependencies//:requirements.bzl", "requirement")
grpc_proto_library( grpc_proto_library(
name = "auth_sample", name = "auth_sample",

@ -255,6 +255,10 @@ main loop in `HandleRpcs` to query the queue.
For a working example, refer to [greeter_async_server.cc](greeter_async_server.cc). For a working example, refer to [greeter_async_server.cc](greeter_async_server.cc).
#### Flags for the client
```sh
./greeter_client --target="a target string used to create a GRPC client channel"
```
The Default value for --target is "localhost:50051".

@ -73,11 +73,32 @@ class GreeterClient {
int main(int argc, char** argv) { int main(int argc, char** argv) {
// Instantiate the client. It requires a channel, out of which the actual RPCs // Instantiate the client. It requires a channel, out of which the actual RPCs
// are created. This channel models a connection to an endpoint (in this case, // are created. This channel models a connection to an endpoint specified by
// localhost at port 50051). We indicate that the channel isn't authenticated // the argument "--target=" which is the only expected argument.
// (use of InsecureChannelCredentials()). // We indicate that the channel isn't authenticated (use of
// InsecureChannelCredentials()).
std::string target_str;
std::string arg_str("--target");
if (argc > 1) {
std::string arg_val = argv[1];
size_t start_pos = arg_val.find(arg_str);
if (start_pos != std::string::npos) {
start_pos += arg_str.size();
if (arg_val[start_pos] == '=') {
target_str = arg_val.substr(start_pos + 1);
} else {
std::cout << "The only correct argument syntax is --target=" << std::endl;
return 0;
}
} else {
std::cout << "The only acceptable argument is --target=" << std::endl;
return 0;
}
} else {
target_str = "localhost:50051";
}
GreeterClient greeter(grpc::CreateChannel( GreeterClient greeter(grpc::CreateChannel(
"localhost:50051", grpc::InsecureChannelCredentials())); target_str, grpc::InsecureChannelCredentials()));
std::string user("world"); std::string user("world");
std::string reply = greeter.SayHello(user); std::string reply = greeter.SayHello(user);
std::cout << "Greeter received: " << reply << std::endl; std::cout << "Greeter received: " << reply << std::endl;

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
load("@grpc_python_dependencies//:requirements.bzl", "requirement") load("@grpc_python_dependencies//:requirements.bzl", "requirement")
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library")
package(default_testonly = 1) package(default_testonly = 1)

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library")
proto_library( proto_library(

@ -63,7 +63,7 @@ class CdsLb : public LoadBalancingPolicy {
public: public:
explicit ClusterWatcher(RefCountedPtr<CdsLb> parent) explicit ClusterWatcher(RefCountedPtr<CdsLb> parent)
: parent_(std::move(parent)) {} : parent_(std::move(parent)) {}
void OnClusterChanged(CdsUpdate cluster_data) override; void OnClusterChanged(XdsApi::CdsUpdate cluster_data) override;
void OnError(grpc_error* error) override; void OnError(grpc_error* error) override;
private: private:
@ -111,7 +111,7 @@ class CdsLb : public LoadBalancingPolicy {
// CdsLb::ClusterWatcher // CdsLb::ClusterWatcher
// //
void CdsLb::ClusterWatcher::OnClusterChanged(CdsUpdate cluster_data) { void CdsLb::ClusterWatcher::OnClusterChanged(XdsApi::CdsUpdate cluster_data) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) { if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) {
gpr_log(GPR_INFO, "[cdslb %p] received CDS update from xds client", gpr_log(GPR_INFO, "[cdslb %p] received CDS update from xds client",
parent_.get()); parent_.get());

@ -172,7 +172,7 @@ class XdsLb : public LoadBalancingPolicy {
RefCountedPtr<XdsLb> xds_policy_; RefCountedPtr<XdsLb> xds_policy_;
PickerList pickers_; PickerList pickers_;
RefCountedPtr<XdsDropConfig> drop_config_; RefCountedPtr<XdsApi::DropConfig> drop_config_;
}; };
class FallbackHelper : public ChannelControlHelper { class FallbackHelper : public ChannelControlHelper {
@ -286,7 +286,7 @@ class XdsLb : public LoadBalancingPolicy {
~LocalityMap() { xds_policy_.reset(DEBUG_LOCATION, "LocalityMap"); } ~LocalityMap() { xds_policy_.reset(DEBUG_LOCATION, "LocalityMap"); }
void UpdateLocked( void UpdateLocked(
const XdsPriorityListUpdate::LocalityMap& locality_map_update); const XdsApi::PriorityListUpdate::LocalityMap& locality_map_update);
void ResetBackoffLocked(); void ResetBackoffLocked();
void UpdateXdsPickerLocked(); void UpdateXdsPickerLocked();
OrphanablePtr<Locality> ExtractLocalityLocked( OrphanablePtr<Locality> ExtractLocalityLocked(
@ -316,10 +316,10 @@ class XdsLb : public LoadBalancingPolicy {
static void OnDelayedRemovalTimerLocked(void* arg, grpc_error* error); static void OnDelayedRemovalTimerLocked(void* arg, grpc_error* error);
static void OnFailoverTimerLocked(void* arg, grpc_error* error); static void OnFailoverTimerLocked(void* arg, grpc_error* error);
const XdsPriorityListUpdate& priority_list_update() const { const XdsApi::PriorityListUpdate& priority_list_update() const {
return xds_policy_->priority_list_update_; return xds_policy_->priority_list_update_;
} }
const XdsPriorityListUpdate::LocalityMap* locality_map_update() const { const XdsApi::PriorityListUpdate::LocalityMap* locality_map_update() const {
return xds_policy_->priority_list_update_.Find(priority_); return xds_policy_->priority_list_update_.Find(priority_);
} }
@ -431,10 +431,10 @@ class XdsLb : public LoadBalancingPolicy {
// The priority that is being used. // The priority that is being used.
uint32_t current_priority_ = UINT32_MAX; uint32_t current_priority_ = UINT32_MAX;
// The update for priority_list_. // The update for priority_list_.
XdsPriorityListUpdate priority_list_update_; XdsApi::PriorityListUpdate priority_list_update_;
// The config for dropping calls. // The config for dropping calls.
RefCountedPtr<XdsDropConfig> drop_config_; RefCountedPtr<XdsApi::DropConfig> drop_config_;
// The stats for client-side load reporting. // The stats for client-side load reporting.
XdsClientStats client_stats_; XdsClientStats client_stats_;
@ -594,7 +594,7 @@ class XdsLb::EndpointWatcher : public XdsClient::EndpointWatcherInterface {
~EndpointWatcher() { xds_policy_.reset(DEBUG_LOCATION, "EndpointWatcher"); } ~EndpointWatcher() { xds_policy_.reset(DEBUG_LOCATION, "EndpointWatcher"); }
void OnEndpointChanged(EdsUpdate update) override { void OnEndpointChanged(XdsApi::EdsUpdate update) override {
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_trace)) { if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_trace)) {
gpr_log(GPR_INFO, "[xdslb %p] Received EDS update from xds client", gpr_log(GPR_INFO, "[xdslb %p] Received EDS update from xds client",
xds_policy_.get()); xds_policy_.get());
@ -1032,6 +1032,8 @@ void XdsLb::UpdatePrioritiesLocked() {
for (uint32_t priority = 0; priority < priorities_.size(); ++priority) { for (uint32_t priority = 0; priority < priorities_.size(); ++priority) {
LocalityMap* locality_map = priorities_[priority].get(); LocalityMap* locality_map = priorities_[priority].get();
const auto* locality_map_update = priority_list_update_.Find(priority); const auto* locality_map_update = priority_list_update_.Find(priority);
// If we have more current priorities than exist in the update, stop here.
if (locality_map_update == nullptr) break;
// Propagate locality_map_update. // Propagate locality_map_update.
// TODO(juanlishen): Find a clean way to skip duplicate update for a // TODO(juanlishen): Find a clean way to skip duplicate update for a
// priority. // priority.
@ -1154,7 +1156,7 @@ XdsLb::LocalityMap::LocalityMap(RefCountedPtr<XdsLb> xds_policy,
} }
void XdsLb::LocalityMap::UpdateLocked( void XdsLb::LocalityMap::UpdateLocked(
const XdsPriorityListUpdate::LocalityMap& locality_map_update) { const XdsApi::PriorityListUpdate::LocalityMap& locality_map_update) {
if (xds_policy_->shutting_down_) return; if (xds_policy_->shutting_down_) return;
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_trace)) { if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_trace)) {
gpr_log(GPR_INFO, "[xdslb %p] Start Updating priority %" PRIu32, gpr_log(GPR_INFO, "[xdslb %p] Start Updating priority %" PRIu32,

@ -56,8 +56,12 @@
namespace grpc_core { namespace grpc_core {
bool XdsPriorityListUpdate::operator==( //
const XdsPriorityListUpdate& other) const { // XdsApi::PriorityListUpdate
//
bool XdsApi::PriorityListUpdate::operator==(
const XdsApi::PriorityListUpdate& other) const {
if (priorities_.size() != other.priorities_.size()) return false; if (priorities_.size() != other.priorities_.size()) return false;
for (size_t i = 0; i < priorities_.size(); ++i) { for (size_t i = 0; i < priorities_.size(); ++i) {
if (priorities_[i].localities != other.priorities_[i].localities) { if (priorities_[i].localities != other.priorities_[i].localities) {
@ -67,8 +71,8 @@ bool XdsPriorityListUpdate::operator==(
return true; return true;
} }
void XdsPriorityListUpdate::Add( void XdsApi::PriorityListUpdate::Add(
XdsPriorityListUpdate::LocalityMap::Locality locality) { XdsApi::PriorityListUpdate::LocalityMap::Locality locality) {
// Pad the missing priorities in case the localities are not ordered by // Pad the missing priorities in case the localities are not ordered by
// priority. // priority.
if (!Contains(locality.priority)) priorities_.resize(locality.priority + 1); if (!Contains(locality.priority)) priorities_.resize(locality.priority + 1);
@ -76,13 +80,13 @@ void XdsPriorityListUpdate::Add(
locality_map.localities.emplace(locality.name, std::move(locality)); locality_map.localities.emplace(locality.name, std::move(locality));
} }
const XdsPriorityListUpdate::LocalityMap* XdsPriorityListUpdate::Find( const XdsApi::PriorityListUpdate::LocalityMap* XdsApi::PriorityListUpdate::Find(
uint32_t priority) const { uint32_t priority) const {
if (!Contains(priority)) return nullptr; if (!Contains(priority)) return nullptr;
return &priorities_[priority]; return &priorities_[priority];
} }
bool XdsPriorityListUpdate::Contains( bool XdsApi::PriorityListUpdate::Contains(
const RefCountedPtr<XdsLocalityName>& name) { const RefCountedPtr<XdsLocalityName>& name) {
for (size_t i = 0; i < priorities_.size(); ++i) { for (size_t i = 0; i < priorities_.size(); ++i) {
const LocalityMap& locality_map = priorities_[i]; const LocalityMap& locality_map = priorities_[i];
@ -91,7 +95,11 @@ bool XdsPriorityListUpdate::Contains(
return false; return false;
} }
bool XdsDropConfig::ShouldDrop(const std::string** category_name) const { //
// XdsApi::DropConfig
//
bool XdsApi::DropConfig::ShouldDrop(const std::string** category_name) const {
for (size_t i = 0; i < drop_category_list_.size(); ++i) { for (size_t i = 0; i < drop_category_list_.size(); ++i) {
const auto& drop_category = drop_category_list_[i]; const auto& drop_category = drop_category_list_[i];
// Generate a random number in [0, 1000000). // Generate a random number in [0, 1000000).
@ -104,6 +112,17 @@ bool XdsDropConfig::ShouldDrop(const std::string** category_name) const {
return false; return false;
} }
//
// XdsApi
//
const char* XdsApi::kLdsTypeUrl = "type.googleapis.com/envoy.api.v2.Listener";
const char* XdsApi::kRdsTypeUrl =
"type.googleapis.com/envoy.api.v2.RouteConfiguration";
const char* XdsApi::kCdsTypeUrl = "type.googleapis.com/envoy.api.v2.Cluster";
const char* XdsApi::kEdsTypeUrl =
"type.googleapis.com/envoy.api.v2.ClusterLoadAssignment";
namespace { namespace {
void PopulateMetadataValue(upb_arena* arena, google_protobuf_Value* value_pb, void PopulateMetadataValue(upb_arena* arena, google_protobuf_Value* value_pb,
@ -203,67 +222,21 @@ void PopulateNode(upb_arena* arena, const XdsBootstrap::Node* node,
upb_strview_makez(build_version)); upb_strview_makez(build_version));
} }
} // namespace envoy_api_v2_DiscoveryRequest* CreateDiscoveryRequest(
upb_arena* arena, const char* type_url, const std::string& version,
grpc_slice XdsUnsupportedTypeNackRequestCreateAndEncode( const std::string& nonce, grpc_error* error, const XdsBootstrap::Node* node,
const std::string& type_url, const std::string& nonce, grpc_error* error) { const char* build_version) {
upb::Arena arena;
// Create a request. // Create a request.
envoy_api_v2_DiscoveryRequest* request = envoy_api_v2_DiscoveryRequest* request =
envoy_api_v2_DiscoveryRequest_new(arena.ptr()); envoy_api_v2_DiscoveryRequest_new(arena);
// Set type_url. // Set type_url.
envoy_api_v2_DiscoveryRequest_set_type_url( envoy_api_v2_DiscoveryRequest_set_type_url(request,
request, upb_strview_makez(type_url.c_str())); upb_strview_makez(type_url));
// Set nonce.
envoy_api_v2_DiscoveryRequest_set_response_nonce(
request, upb_strview_makez(nonce.c_str()));
// Set error_detail.
grpc_slice error_description_slice;
GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION,
&error_description_slice));
upb_strview error_description_strview =
upb_strview_make(reinterpret_cast<const char*>(
GPR_SLICE_START_PTR(error_description_slice)),
GPR_SLICE_LENGTH(error_description_slice));
google_rpc_Status* error_detail =
envoy_api_v2_DiscoveryRequest_mutable_error_detail(request, arena.ptr());
google_rpc_Status_set_message(error_detail, error_description_strview);
GRPC_ERROR_UNREF(error);
// Encode the request.
size_t output_length;
char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(),
&output_length);
return grpc_slice_from_copied_buffer(output, output_length);
}
grpc_slice XdsLdsRequestCreateAndEncode(const std::string& server_name,
const XdsBootstrap::Node* node,
const char* build_version,
const std::string& version,
const std::string& nonce,
grpc_error* error) {
upb::Arena arena;
// Create a request.
envoy_api_v2_DiscoveryRequest* request =
envoy_api_v2_DiscoveryRequest_new(arena.ptr());
// Set version_info. // Set version_info.
if (!version.empty()) { if (!version.empty()) {
envoy_api_v2_DiscoveryRequest_set_version_info( envoy_api_v2_DiscoveryRequest_set_version_info(
request, upb_strview_makez(version.c_str())); request, upb_strview_makez(version.c_str()));
} }
// Populate node.
if (build_version != nullptr) {
envoy_api_v2_core_Node* node_msg =
envoy_api_v2_DiscoveryRequest_mutable_node(request, arena.ptr());
PopulateNode(arena.ptr(), node, build_version, node_msg);
}
// Add resource_name.
envoy_api_v2_DiscoveryRequest_add_resource_names(
request, upb_strview_make(server_name.data(), server_name.size()),
arena.ptr());
// Set type_url.
envoy_api_v2_DiscoveryRequest_set_type_url(request,
upb_strview_makez(kLdsTypeUrl));
// Set nonce. // Set nonce.
if (!nonce.empty()) { if (!nonce.empty()) {
envoy_api_v2_DiscoveryRequest_set_response_nonce( envoy_api_v2_DiscoveryRequest_set_response_nonce(
@ -279,148 +252,98 @@ grpc_slice XdsLdsRequestCreateAndEncode(const std::string& server_name,
GPR_SLICE_START_PTR(error_description_slice)), GPR_SLICE_START_PTR(error_description_slice)),
GPR_SLICE_LENGTH(error_description_slice)); GPR_SLICE_LENGTH(error_description_slice));
google_rpc_Status* error_detail = google_rpc_Status* error_detail =
envoy_api_v2_DiscoveryRequest_mutable_error_detail(request, envoy_api_v2_DiscoveryRequest_mutable_error_detail(request, arena);
arena.ptr());
google_rpc_Status_set_message(error_detail, error_description_strview); google_rpc_Status_set_message(error_detail, error_description_strview);
GRPC_ERROR_UNREF(error); GRPC_ERROR_UNREF(error);
} }
// Encode the request. // Populate node.
if (build_version != nullptr) {
envoy_api_v2_core_Node* node_msg =
envoy_api_v2_DiscoveryRequest_mutable_node(request, arena);
PopulateNode(arena, node, build_version, node_msg);
}
return request;
}
grpc_slice SerializeDiscoveryRequest(upb_arena* arena,
envoy_api_v2_DiscoveryRequest* request) {
size_t output_length; size_t output_length;
char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(), char* output =
&output_length); envoy_api_v2_DiscoveryRequest_serialize(request, arena, &output_length);
return grpc_slice_from_copied_buffer(output, output_length); return grpc_slice_from_copied_buffer(output, output_length);
} }
grpc_slice XdsRdsRequestCreateAndEncode(const std::string& route_config_name, } // namespace
const XdsBootstrap::Node* node,
const char* build_version, grpc_slice XdsApi::CreateUnsupportedTypeNackRequest(const std::string& type_url,
const std::string& version, const std::string& nonce,
const std::string& nonce, grpc_error* error) {
grpc_error* error) { upb::Arena arena;
envoy_api_v2_DiscoveryRequest* request = CreateDiscoveryRequest(
arena.ptr(), type_url.c_str(), /*version=*/"", nonce, error,
/*node=*/nullptr, /*build_version=*/nullptr);
return SerializeDiscoveryRequest(arena.ptr(), request);
}
grpc_slice XdsApi::CreateLdsRequest(const std::string& server_name,
const std::string& version,
const std::string& nonce, grpc_error* error,
bool populate_node) {
upb::Arena arena; upb::Arena arena;
// Create a request.
envoy_api_v2_DiscoveryRequest* request = envoy_api_v2_DiscoveryRequest* request =
envoy_api_v2_DiscoveryRequest_new(arena.ptr()); CreateDiscoveryRequest(arena.ptr(), kLdsTypeUrl, version, nonce, error,
// Set version_info. populate_node ? node_ : nullptr,
if (!version.empty()) { populate_node ? build_version_ : nullptr);
envoy_api_v2_DiscoveryRequest_set_version_info( // Add resource_name.
request, upb_strview_makez(version.c_str())); envoy_api_v2_DiscoveryRequest_add_resource_names(
} request, upb_strview_make(server_name.data(), server_name.size()),
// Populate node. arena.ptr());
if (build_version != nullptr) { return SerializeDiscoveryRequest(arena.ptr(), request);
envoy_api_v2_core_Node* node_msg = }
envoy_api_v2_DiscoveryRequest_mutable_node(request, arena.ptr());
PopulateNode(arena.ptr(), node, build_version, node_msg); grpc_slice XdsApi::CreateRdsRequest(const std::string& route_config_name,
} const std::string& version,
const std::string& nonce, grpc_error* error,
bool populate_node) {
upb::Arena arena;
envoy_api_v2_DiscoveryRequest* request =
CreateDiscoveryRequest(arena.ptr(), kRdsTypeUrl, version, nonce, error,
populate_node ? node_ : nullptr,
populate_node ? build_version_ : nullptr);
// Add resource_name. // Add resource_name.
envoy_api_v2_DiscoveryRequest_add_resource_names( envoy_api_v2_DiscoveryRequest_add_resource_names(
request, request,
upb_strview_make(route_config_name.data(), route_config_name.size()), upb_strview_make(route_config_name.data(), route_config_name.size()),
arena.ptr()); arena.ptr());
// Set type_url. return SerializeDiscoveryRequest(arena.ptr(), request);
envoy_api_v2_DiscoveryRequest_set_type_url(request,
upb_strview_makez(kRdsTypeUrl));
// Set nonce.
if (!nonce.empty()) {
envoy_api_v2_DiscoveryRequest_set_response_nonce(
request, upb_strview_makez(nonce.c_str()));
}
// Set error_detail if it's a NACK.
if (error != GRPC_ERROR_NONE) {
grpc_slice error_description_slice;
GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION,
&error_description_slice));
upb_strview error_description_strview =
upb_strview_make(reinterpret_cast<const char*>(
GPR_SLICE_START_PTR(error_description_slice)),
GPR_SLICE_LENGTH(error_description_slice));
google_rpc_Status* error_detail =
envoy_api_v2_DiscoveryRequest_mutable_error_detail(request,
arena.ptr());
google_rpc_Status_set_message(error_detail, error_description_strview);
GRPC_ERROR_UNREF(error);
}
// Encode the request.
size_t output_length;
char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(),
&output_length);
return grpc_slice_from_copied_buffer(output, output_length);
} }
grpc_slice XdsCdsRequestCreateAndEncode( grpc_slice XdsApi::CreateCdsRequest(const std::set<StringView>& cluster_names,
const std::set<StringView>& cluster_names, const XdsBootstrap::Node* node, const std::string& version,
const char* build_version, const std::string& version, const std::string& nonce, grpc_error* error,
const std::string& nonce, grpc_error* error) { bool populate_node) {
upb::Arena arena; upb::Arena arena;
// Create a request.
envoy_api_v2_DiscoveryRequest* request = envoy_api_v2_DiscoveryRequest* request =
envoy_api_v2_DiscoveryRequest_new(arena.ptr()); CreateDiscoveryRequest(arena.ptr(), kCdsTypeUrl, version, nonce, error,
// Set version_info. populate_node ? node_ : nullptr,
if (!version.empty()) { populate_node ? build_version_ : nullptr);
envoy_api_v2_DiscoveryRequest_set_version_info(
request, upb_strview_makez(version.c_str()));
}
// Populate node.
if (build_version != nullptr) {
envoy_api_v2_core_Node* node_msg =
envoy_api_v2_DiscoveryRequest_mutable_node(request, arena.ptr());
PopulateNode(arena.ptr(), node, build_version, node_msg);
}
// Add resource_names. // Add resource_names.
for (const auto& cluster_name : cluster_names) { for (const auto& cluster_name : cluster_names) {
envoy_api_v2_DiscoveryRequest_add_resource_names( envoy_api_v2_DiscoveryRequest_add_resource_names(
request, upb_strview_make(cluster_name.data(), cluster_name.size()), request, upb_strview_make(cluster_name.data(), cluster_name.size()),
arena.ptr()); arena.ptr());
} }
// Set type_url. return SerializeDiscoveryRequest(arena.ptr(), request);
envoy_api_v2_DiscoveryRequest_set_type_url(request,
upb_strview_makez(kCdsTypeUrl));
// Set nonce.
if (!nonce.empty()) {
envoy_api_v2_DiscoveryRequest_set_response_nonce(
request, upb_strview_makez(nonce.c_str()));
}
// Set error_detail if it's a NACK.
if (error != GRPC_ERROR_NONE) {
grpc_slice error_description_slice;
GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION,
&error_description_slice));
upb_strview error_description_strview =
upb_strview_make(reinterpret_cast<const char*>(
GPR_SLICE_START_PTR(error_description_slice)),
GPR_SLICE_LENGTH(error_description_slice));
google_rpc_Status* error_detail =
envoy_api_v2_DiscoveryRequest_mutable_error_detail(request,
arena.ptr());
google_rpc_Status_set_message(error_detail, error_description_strview);
GRPC_ERROR_UNREF(error);
}
// Encode the request.
size_t output_length;
char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(),
&output_length);
return grpc_slice_from_copied_buffer(output, output_length);
} }
grpc_slice XdsEdsRequestCreateAndEncode( grpc_slice XdsApi::CreateEdsRequest(
const std::set<StringView>& eds_service_names, const std::set<StringView>& eds_service_names, const std::string& version,
const XdsBootstrap::Node* node, const char* build_version, const std::string& nonce, grpc_error* error, bool populate_node) {
const std::string& version, const std::string& nonce, grpc_error* error) {
upb::Arena arena; upb::Arena arena;
// Create a request.
envoy_api_v2_DiscoveryRequest* request = envoy_api_v2_DiscoveryRequest* request =
envoy_api_v2_DiscoveryRequest_new(arena.ptr()); CreateDiscoveryRequest(arena.ptr(), kEdsTypeUrl, version, nonce, error,
// Set version_info. populate_node ? node_ : nullptr,
if (!version.empty()) { populate_node ? build_version_ : nullptr);
envoy_api_v2_DiscoveryRequest_set_version_info(
request, upb_strview_makez(version.c_str()));
}
// Populate node.
if (build_version != nullptr) {
envoy_api_v2_core_Node* node_msg =
envoy_api_v2_DiscoveryRequest_mutable_node(request, arena.ptr());
PopulateNode(arena.ptr(), node, build_version, node_msg);
}
// Add resource_names. // Add resource_names.
for (const auto& eds_service_name : eds_service_names) { for (const auto& eds_service_name : eds_service_names) {
envoy_api_v2_DiscoveryRequest_add_resource_names( envoy_api_v2_DiscoveryRequest_add_resource_names(
@ -428,34 +351,7 @@ grpc_slice XdsEdsRequestCreateAndEncode(
upb_strview_make(eds_service_name.data(), eds_service_name.size()), upb_strview_make(eds_service_name.data(), eds_service_name.size()),
arena.ptr()); arena.ptr());
} }
// Set type_url. return SerializeDiscoveryRequest(arena.ptr(), request);
envoy_api_v2_DiscoveryRequest_set_type_url(request,
upb_strview_makez(kEdsTypeUrl));
// Set nonce.
if (!nonce.empty()) {
envoy_api_v2_DiscoveryRequest_set_response_nonce(
request, upb_strview_makez(nonce.c_str()));
}
// Set error_detail if it's a NACK.
if (error != GRPC_ERROR_NONE) {
grpc_slice error_description_slice;
GPR_ASSERT(grpc_error_get_str(error, GRPC_ERROR_STR_DESCRIPTION,
&error_description_slice));
upb_strview error_description_strview =
upb_strview_make(reinterpret_cast<const char*>(
GPR_SLICE_START_PTR(error_description_slice)),
GPR_SLICE_LENGTH(error_description_slice));
google_rpc_Status* error_detail =
envoy_api_v2_DiscoveryRequest_mutable_error_detail(request,
arena.ptr());
google_rpc_Status_set_message(error_detail, error_description_strview);
GRPC_ERROR_UNREF(error);
}
// Encode the request.
size_t output_length;
char* output = envoy_api_v2_DiscoveryRequest_serialize(request, arena.ptr(),
&output_length);
return grpc_slice_from_copied_buffer(output, output_length);
} }
namespace { namespace {
@ -511,7 +407,7 @@ MatchType DomainPatternMatchType(const std::string& domain_pattern) {
grpc_error* RouteConfigParse( grpc_error* RouteConfigParse(
const envoy_api_v2_RouteConfiguration* route_config, const envoy_api_v2_RouteConfiguration* route_config,
const std::string& expected_server_name, RdsUpdate* rds_update) { const std::string& expected_server_name, XdsApi::RdsUpdate* rds_update) {
// Strip off port from server name, if any. // Strip off port from server name, if any.
size_t pos = expected_server_name.find(':'); size_t pos = expected_server_name.find(':');
std::string expected_host_name = expected_server_name.substr(0, pos); std::string expected_host_name = expected_server_name.substr(0, pos);
@ -604,11 +500,9 @@ grpc_error* RouteConfigParse(
return GRPC_ERROR_NONE; return GRPC_ERROR_NONE;
} }
} // namespace
grpc_error* LdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, grpc_error* LdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
const std::string& expected_server_name, const std::string& expected_server_name,
LdsUpdate* lds_update, upb_arena* arena) { XdsApi::LdsUpdate* lds_update, upb_arena* arena) {
// Get the resources from the response. // Get the resources from the response.
size_t size; size_t size;
const google_protobuf_Any* const* resources = const google_protobuf_Any* const* resources =
@ -620,7 +514,7 @@ grpc_error* LdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
// Check the type_url of the resource. // Check the type_url of the resource.
const upb_strview type_url = google_protobuf_Any_type_url(resources[i]); const upb_strview type_url = google_protobuf_Any_type_url(resources[i]);
if (!upb_strview_eql(type_url, upb_strview_makez(kLdsTypeUrl))) { if (!upb_strview_eql(type_url, upb_strview_makez(XdsApi::kLdsTypeUrl))) {
return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not LDS."); return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not LDS.");
} }
// Decode the listener. // Decode the listener.
@ -655,7 +549,7 @@ grpc_error* LdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
const envoy_api_v2_RouteConfiguration* route_config = const envoy_api_v2_RouteConfiguration* route_config =
envoy_config_filter_network_http_connection_manager_v2_HttpConnectionManager_route_config( envoy_config_filter_network_http_connection_manager_v2_HttpConnectionManager_route_config(
http_connection_manager); http_connection_manager);
RdsUpdate rds_update; XdsApi::RdsUpdate rds_update;
grpc_error* error = grpc_error* error =
RouteConfigParse(route_config, expected_server_name, &rds_update); RouteConfigParse(route_config, expected_server_name, &rds_update);
if (error != GRPC_ERROR_NONE) return error; if (error != GRPC_ERROR_NONE) return error;
@ -690,7 +584,7 @@ grpc_error* LdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
grpc_error* RdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, grpc_error* RdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
const std::string& expected_server_name, const std::string& expected_server_name,
const std::string& expected_route_config_name, const std::string& expected_route_config_name,
RdsUpdate* rds_update, upb_arena* arena) { XdsApi::RdsUpdate* rds_update, upb_arena* arena) {
// Get the resources from the response. // Get the resources from the response.
size_t size; size_t size;
const google_protobuf_Any* const* resources = const google_protobuf_Any* const* resources =
@ -702,7 +596,7 @@ grpc_error* RdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
// Check the type_url of the resource. // Check the type_url of the resource.
const upb_strview type_url = google_protobuf_Any_type_url(resources[i]); const upb_strview type_url = google_protobuf_Any_type_url(resources[i]);
if (!upb_strview_eql(type_url, upb_strview_makez(kRdsTypeUrl))) { if (!upb_strview_eql(type_url, upb_strview_makez(XdsApi::kRdsTypeUrl))) {
return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not RDS."); return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not RDS.");
} }
// Decode the route_config. // Decode the route_config.
@ -720,7 +614,7 @@ grpc_error* RdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
upb_strview_makez(expected_route_config_name.c_str()); upb_strview_makez(expected_route_config_name.c_str());
if (!upb_strview_eql(name, expected_name)) continue; if (!upb_strview_eql(name, expected_name)) continue;
// Parse the route_config. // Parse the route_config.
RdsUpdate local_rds_update; XdsApi::RdsUpdate local_rds_update;
grpc_error* error = grpc_error* error =
RouteConfigParse(route_config, expected_server_name, &local_rds_update); RouteConfigParse(route_config, expected_server_name, &local_rds_update);
if (error != GRPC_ERROR_NONE) return error; if (error != GRPC_ERROR_NONE) return error;
@ -732,7 +626,8 @@ grpc_error* RdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
} }
grpc_error* CdsResponseParse(const envoy_api_v2_DiscoveryResponse* response, grpc_error* CdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
CdsUpdateMap* cds_update_map, upb_arena* arena) { XdsApi::CdsUpdateMap* cds_update_map,
upb_arena* arena) {
// Get the resources from the response. // Get the resources from the response.
size_t size; size_t size;
const google_protobuf_Any* const* resources = const google_protobuf_Any* const* resources =
@ -743,10 +638,10 @@ grpc_error* CdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
} }
// Parse all the resources in the CDS response. // Parse all the resources in the CDS response.
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
CdsUpdate cds_update; XdsApi::CdsUpdate cds_update;
// Check the type_url of the resource. // Check the type_url of the resource.
const upb_strview type_url = google_protobuf_Any_type_url(resources[i]); const upb_strview type_url = google_protobuf_Any_type_url(resources[i]);
if (!upb_strview_eql(type_url, upb_strview_makez(kCdsTypeUrl))) { if (!upb_strview_eql(type_url, upb_strview_makez(XdsApi::kCdsTypeUrl))) {
return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not CDS."); return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not CDS.");
} }
// Decode the cluster. // Decode the cluster.
@ -801,8 +696,6 @@ grpc_error* CdsResponseParse(const envoy_api_v2_DiscoveryResponse* response,
return GRPC_ERROR_NONE; return GRPC_ERROR_NONE;
} }
namespace {
grpc_error* ServerAddressParseAndAppend( grpc_error* ServerAddressParseAndAppend(
const envoy_api_v2_endpoint_LbEndpoint* lb_endpoint, const envoy_api_v2_endpoint_LbEndpoint* lb_endpoint,
ServerAddressList* list) { ServerAddressList* list) {
@ -840,7 +733,7 @@ grpc_error* ServerAddressParseAndAppend(
grpc_error* LocalityParse( grpc_error* LocalityParse(
const envoy_api_v2_endpoint_LocalityLbEndpoints* locality_lb_endpoints, const envoy_api_v2_endpoint_LocalityLbEndpoints* locality_lb_endpoints,
XdsPriorityListUpdate::LocalityMap::Locality* output_locality) { XdsApi::PriorityListUpdate::LocalityMap::Locality* output_locality) {
// Parse LB weight. // Parse LB weight.
const google_protobuf_UInt32Value* lb_weight = const google_protobuf_UInt32Value* lb_weight =
envoy_api_v2_endpoint_LocalityLbEndpoints_load_balancing_weight( envoy_api_v2_endpoint_LocalityLbEndpoints_load_balancing_weight(
@ -878,7 +771,7 @@ grpc_error* LocalityParse(
grpc_error* DropParseAndAppend( grpc_error* DropParseAndAppend(
const envoy_api_v2_ClusterLoadAssignment_Policy_DropOverload* drop_overload, const envoy_api_v2_ClusterLoadAssignment_Policy_DropOverload* drop_overload,
XdsDropConfig* drop_config, bool* drop_all) { XdsApi::DropConfig* drop_config, bool* drop_all) {
// Get the category. // Get the category.
upb_strview category = upb_strview category =
envoy_api_v2_ClusterLoadAssignment_Policy_DropOverload_category( envoy_api_v2_ClusterLoadAssignment_Policy_DropOverload_category(
@ -918,7 +811,7 @@ grpc_error* DropParseAndAppend(
grpc_error* EdsResponsedParse( grpc_error* EdsResponsedParse(
const envoy_api_v2_DiscoveryResponse* response, const envoy_api_v2_DiscoveryResponse* response,
const std::set<StringView>& expected_eds_service_names, const std::set<StringView>& expected_eds_service_names,
EdsUpdateMap* eds_update_map, upb_arena* arena) { XdsApi::EdsUpdateMap* eds_update_map, upb_arena* arena) {
// Get the resources from the response. // Get the resources from the response.
size_t size; size_t size;
const google_protobuf_Any* const* resources = const google_protobuf_Any* const* resources =
@ -928,10 +821,10 @@ grpc_error* EdsResponsedParse(
"EDS response contains 0 resource."); "EDS response contains 0 resource.");
} }
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
EdsUpdate eds_update; XdsApi::EdsUpdate eds_update;
// Check the type_url of the resource. // Check the type_url of the resource.
upb_strview type_url = google_protobuf_Any_type_url(resources[i]); upb_strview type_url = google_protobuf_Any_type_url(resources[i]);
if (!upb_strview_eql(type_url, upb_strview_makez(kEdsTypeUrl))) { if (!upb_strview_eql(type_url, upb_strview_makez(XdsApi::kEdsTypeUrl))) {
return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not EDS."); return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Resource is not EDS.");
} }
// Get the cluster_load_assignment. // Get the cluster_load_assignment.
@ -960,7 +853,7 @@ grpc_error* EdsResponsedParse(
envoy_api_v2_ClusterLoadAssignment_endpoints(cluster_load_assignment, envoy_api_v2_ClusterLoadAssignment_endpoints(cluster_load_assignment,
&locality_size); &locality_size);
for (size_t j = 0; j < locality_size; ++j) { for (size_t j = 0; j < locality_size; ++j) {
XdsPriorityListUpdate::LocalityMap::Locality locality; XdsApi::PriorityListUpdate::LocalityMap::Locality locality;
grpc_error* error = LocalityParse(endpoints[j], &locality); grpc_error* error = LocalityParse(endpoints[j], &locality);
if (error != GRPC_ERROR_NONE) return error; if (error != GRPC_ERROR_NONE) return error;
// Filter out locality with weight 0. // Filter out locality with weight 0.
@ -968,7 +861,7 @@ grpc_error* EdsResponsedParse(
eds_update.priority_list_update.Add(locality); eds_update.priority_list_update.Add(locality);
} }
// Get the drop config. // Get the drop config.
eds_update.drop_config = MakeRefCounted<XdsDropConfig>(); eds_update.drop_config = MakeRefCounted<XdsApi::DropConfig>();
const envoy_api_v2_ClusterLoadAssignment_Policy* policy = const envoy_api_v2_ClusterLoadAssignment_Policy* policy =
envoy_api_v2_ClusterLoadAssignment_policy(cluster_load_assignment); envoy_api_v2_ClusterLoadAssignment_policy(cluster_load_assignment);
if (policy != nullptr) { if (policy != nullptr) {
@ -998,7 +891,7 @@ grpc_error* EdsResponsedParse(
} // namespace } // namespace
grpc_error* XdsAdsResponseDecodeAndParse( grpc_error* XdsApi::ParseAdsResponse(
const grpc_slice& encoded_response, const std::string& expected_server_name, const grpc_slice& encoded_response, const std::string& expected_server_name,
const std::string& expected_route_config_name, const std::string& expected_route_config_name,
const std::set<StringView>& expected_eds_service_names, const std::set<StringView>& expected_eds_service_names,
@ -1047,7 +940,7 @@ grpc_error* XdsAdsResponseDecodeAndParse(
namespace { namespace {
grpc_slice LrsRequestEncode( grpc_slice SerializeLrsRequest(
const envoy_service_load_stats_v2_LoadStatsRequest* request, const envoy_service_load_stats_v2_LoadStatsRequest* request,
upb_arena* arena) { upb_arena* arena) {
size_t output_length; size_t output_length;
@ -1058,9 +951,7 @@ grpc_slice LrsRequestEncode(
} // namespace } // namespace
grpc_slice XdsLrsRequestCreateAndEncode(const std::string& server_name, grpc_slice XdsApi::CreateLrsInitialRequest(const std::string& server_name) {
const XdsBootstrap::Node* node,
const char* build_version) {
upb::Arena arena; upb::Arena arena;
// Create a request. // Create a request.
envoy_service_load_stats_v2_LoadStatsRequest* request = envoy_service_load_stats_v2_LoadStatsRequest* request =
@ -1069,7 +960,7 @@ grpc_slice XdsLrsRequestCreateAndEncode(const std::string& server_name,
envoy_api_v2_core_Node* node_msg = envoy_api_v2_core_Node* node_msg =
envoy_service_load_stats_v2_LoadStatsRequest_mutable_node(request, envoy_service_load_stats_v2_LoadStatsRequest_mutable_node(request,
arena.ptr()); arena.ptr());
PopulateNode(arena.ptr(), node, build_version, node_msg); PopulateNode(arena.ptr(), node_, build_version_, node_msg);
// Add cluster stats. There is only one because we only use one server name in // Add cluster stats. There is only one because we only use one server name in
// one channel. // one channel.
envoy_api_v2_endpoint_ClusterStats* cluster_stats = envoy_api_v2_endpoint_ClusterStats* cluster_stats =
@ -1078,7 +969,7 @@ grpc_slice XdsLrsRequestCreateAndEncode(const std::string& server_name,
// Set the cluster name. // Set the cluster name.
envoy_api_v2_endpoint_ClusterStats_set_cluster_name( envoy_api_v2_endpoint_ClusterStats_set_cluster_name(
cluster_stats, upb_strview_makez(server_name.c_str())); cluster_stats, upb_strview_makez(server_name.c_str()));
return LrsRequestEncode(request, arena.ptr()); return SerializeLrsRequest(request, arena.ptr());
} }
namespace { namespace {
@ -1123,7 +1014,7 @@ void LocalityStatsPopulate(
} // namespace } // namespace
grpc_slice XdsLrsRequestCreateAndEncode( grpc_slice XdsApi::CreateLrsRequest(
std::map<StringView, std::set<XdsClientStats*>, StringLess> std::map<StringView, std::set<XdsClientStats*>, StringLess>
client_stats_map) { client_stats_map) {
upb::Arena arena; upb::Arena arena;
@ -1193,12 +1084,12 @@ grpc_slice XdsLrsRequestCreateAndEncode(
timespec.tv_nsec); timespec.tv_nsec);
} }
} }
return LrsRequestEncode(request, arena.ptr()); return SerializeLrsRequest(request, arena.ptr());
} }
grpc_error* XdsLrsResponseDecodeAndParse(const grpc_slice& encoded_response, grpc_error* XdsApi::ParseLrsResponse(const grpc_slice& encoded_response,
std::set<std::string>* cluster_names, std::set<std::string>* cluster_names,
grpc_millis* load_reporting_interval) { grpc_millis* load_reporting_interval) {
upb::Arena arena; upb::Arena arena;
// Decode the response. // Decode the response.
const envoy_service_load_stats_v2_LoadStatsResponse* decoded_response = const envoy_service_load_stats_v2_LoadStatsResponse* decoded_response =

@ -34,215 +34,218 @@
namespace grpc_core { namespace grpc_core {
constexpr char kLdsTypeUrl[] = "type.googleapis.com/envoy.api.v2.Listener"; class XdsApi {
constexpr char kRdsTypeUrl[] = public:
"type.googleapis.com/envoy.api.v2.RouteConfiguration"; static const char* kLdsTypeUrl;
constexpr char kCdsTypeUrl[] = "type.googleapis.com/envoy.api.v2.Cluster"; static const char* kRdsTypeUrl;
constexpr char kEdsTypeUrl[] = static const char* kCdsTypeUrl;
"type.googleapis.com/envoy.api.v2.ClusterLoadAssignment"; static const char* kEdsTypeUrl;
struct RdsUpdate { struct RdsUpdate {
// The name to use in the CDS request. // The name to use in the CDS request.
std::string cluster_name; std::string cluster_name;
}; };
struct LdsUpdate {
// The name to use in the RDS request.
std::string route_config_name;
// The name to use in the CDS request. Present if the LDS response has it
// inlined.
Optional<RdsUpdate> rds_update;
};
using LdsUpdateMap = std::map<std::string /*server_name*/, LdsUpdate>; struct LdsUpdate {
// The name to use in the RDS request.
std::string route_config_name;
// The name to use in the CDS request. Present if the LDS response has it
// inlined.
Optional<RdsUpdate> rds_update;
};
using RdsUpdateMap = std::map<std::string /*route_config_name*/, RdsUpdate>; using LdsUpdateMap = std::map<std::string /*server_name*/, LdsUpdate>;
struct CdsUpdate { using RdsUpdateMap = std::map<std::string /*route_config_name*/, RdsUpdate>;
// The name to use in the EDS request.
// If empty, the cluster name will be used.
std::string eds_service_name;
// The LRS server to use for load reporting.
// If not set, load reporting will be disabled.
// If set to the empty string, will use the same server we obtained the CDS
// data from.
Optional<std::string> lrs_load_reporting_server_name;
};
using CdsUpdateMap = std::map<std::string /*cluster_name*/, CdsUpdate>; struct CdsUpdate {
// The name to use in the EDS request.
// If empty, the cluster name will be used.
std::string eds_service_name;
// The LRS server to use for load reporting.
// If not set, load reporting will be disabled.
// If set to the empty string, will use the same server we obtained the CDS
// data from.
Optional<std::string> lrs_load_reporting_server_name;
};
class XdsPriorityListUpdate { using CdsUpdateMap = std::map<std::string /*cluster_name*/, CdsUpdate>;
public:
struct LocalityMap {
struct Locality {
bool operator==(const Locality& other) const {
return *name == *other.name && serverlist == other.serverlist &&
lb_weight == other.lb_weight && priority == other.priority;
}
// This comparator only compares the locality names. class PriorityListUpdate {
struct Less { public:
bool operator()(const Locality& lhs, const Locality& rhs) const { struct LocalityMap {
return XdsLocalityName::Less()(lhs.name, rhs.name); struct Locality {
bool operator==(const Locality& other) const {
return *name == *other.name && serverlist == other.serverlist &&
lb_weight == other.lb_weight && priority == other.priority;
} }
// This comparator only compares the locality names.
struct Less {
bool operator()(const Locality& lhs, const Locality& rhs) const {
return XdsLocalityName::Less()(lhs.name, rhs.name);
}
};
RefCountedPtr<XdsLocalityName> name;
ServerAddressList serverlist;
uint32_t lb_weight;
uint32_t priority;
}; };
RefCountedPtr<XdsLocalityName> name; bool Contains(const RefCountedPtr<XdsLocalityName>& name) const {
ServerAddressList serverlist; return localities.find(name) != localities.end();
uint32_t lb_weight; }
uint32_t priority;
size_t size() const { return localities.size(); }
std::map<RefCountedPtr<XdsLocalityName>, Locality, XdsLocalityName::Less>
localities;
}; };
bool Contains(const RefCountedPtr<XdsLocalityName>& name) const { bool operator==(const PriorityListUpdate& other) const;
return localities.find(name) != localities.end(); bool operator!=(const PriorityListUpdate& other) const {
return !(*this == other);
} }
size_t size() const { return localities.size(); } void Add(LocalityMap::Locality locality);
std::map<RefCountedPtr<XdsLocalityName>, Locality, XdsLocalityName::Less> const LocalityMap* Find(uint32_t priority) const;
localities;
};
bool operator==(const XdsPriorityListUpdate& other) const; bool Contains(uint32_t priority) const {
bool operator!=(const XdsPriorityListUpdate& other) const { return priority < priorities_.size();
return !(*this == other); }
} bool Contains(const RefCountedPtr<XdsLocalityName>& name);
void Add(LocalityMap::Locality locality); bool empty() const { return priorities_.empty(); }
size_t size() const { return priorities_.size(); }
const LocalityMap* Find(uint32_t priority) const; // Callers should make sure the priority list is non-empty.
uint32_t LowestPriority() const {
return static_cast<uint32_t>(priorities_.size()) - 1;
}
bool Contains(uint32_t priority) const { private:
return priority < priorities_.size(); InlinedVector<LocalityMap, 2> priorities_;
} };
bool Contains(const RefCountedPtr<XdsLocalityName>& name);
bool empty() const { return priorities_.empty(); } // There are two phases of accessing this class's content:
size_t size() const { return priorities_.size(); } // 1. to initialize in the control plane combiner;
// 2. to use in the data plane combiner.
// So no additional synchronization is needed.
class DropConfig : public RefCounted<DropConfig> {
public:
struct DropCategory {
bool operator==(const DropCategory& other) const {
return name == other.name &&
parts_per_million == other.parts_per_million;
}
// Callers should make sure the priority list is non-empty. std::string name;
uint32_t LowestPriority() const { const uint32_t parts_per_million;
return static_cast<uint32_t>(priorities_.size()) - 1; };
}
private: using DropCategoryList = InlinedVector<DropCategory, 2>;
InlinedVector<LocalityMap, 2> priorities_;
};
// There are two phases of accessing this class's content: void AddCategory(std::string name, uint32_t parts_per_million) {
// 1. to initialize in the control plane combiner; drop_category_list_.emplace_back(
// 2. to use in the data plane combiner. DropCategory{std::move(name), parts_per_million});
// So no additional synchronization is needed.
class XdsDropConfig : public RefCounted<XdsDropConfig> {
public:
struct DropCategory {
bool operator==(const DropCategory& other) const {
return name == other.name && parts_per_million == other.parts_per_million;
} }
std::string name; // The only method invoked from the data plane combiner.
const uint32_t parts_per_million; bool ShouldDrop(const std::string** category_name) const;
};
using DropCategoryList = InlinedVector<DropCategory, 2>; const DropCategoryList& drop_category_list() const {
return drop_category_list_;
}
void AddCategory(std::string name, uint32_t parts_per_million) { bool operator==(const DropConfig& other) const {
drop_category_list_.emplace_back( return drop_category_list_ == other.drop_category_list_;
DropCategory{std::move(name), parts_per_million}); }
} bool operator!=(const DropConfig& other) const { return !(*this == other); }
// The only method invoked from the data plane combiner. private:
bool ShouldDrop(const std::string** category_name) const; DropCategoryList drop_category_list_;
};
const DropCategoryList& drop_category_list() const { struct EdsUpdate {
return drop_category_list_; PriorityListUpdate priority_list_update;
} RefCountedPtr<DropConfig> drop_config;
bool drop_all = false;
};
bool operator==(const XdsDropConfig& other) const { using EdsUpdateMap = std::map<std::string /*eds_service_name*/, EdsUpdate>;
return drop_category_list_ == other.drop_category_list_;
} XdsApi(const XdsBootstrap::Node* node, const char* build_version)
bool operator!=(const XdsDropConfig& other) const { : node_(node), build_version_(build_version) {}
return !(*this == other);
} // Creates a request to nack an unsupported resource type.
// Takes ownership of \a error.
grpc_slice CreateUnsupportedTypeNackRequest(const std::string& type_url,
const std::string& nonce,
grpc_error* error);
// Creates an LDS request querying \a server_name.
// Takes ownership of \a error.
grpc_slice CreateLdsRequest(const std::string& server_name,
const std::string& version,
const std::string& nonce, grpc_error* error,
bool populate_node);
// Creates an RDS request querying \a route_config_name.
// Takes ownership of \a error.
grpc_slice CreateRdsRequest(const std::string& route_config_name,
const std::string& version,
const std::string& nonce, grpc_error* error,
bool populate_node);
// Creates a CDS request querying \a cluster_names.
// Takes ownership of \a error.
grpc_slice CreateCdsRequest(const std::set<StringView>& cluster_names,
const std::string& version,
const std::string& nonce, grpc_error* error,
bool populate_node);
// Creates an EDS request querying \a eds_service_names.
// Takes ownership of \a error.
grpc_slice CreateEdsRequest(const std::set<StringView>& eds_service_names,
const std::string& version,
const std::string& nonce, grpc_error* error,
bool populate_node);
// Parses the ADS response and outputs the validated update for either CDS or
// EDS. If the response can't be parsed at the top level, \a type_url will
// point to an empty string; otherwise, it will point to the received data.
grpc_error* ParseAdsResponse(
const grpc_slice& encoded_response,
const std::string& expected_server_name,
const std::string& expected_route_config_name,
const std::set<StringView>& expected_eds_service_names,
LdsUpdate* lds_update, RdsUpdate* rds_update,
CdsUpdateMap* cds_update_map, EdsUpdateMap* eds_update_map,
std::string* version, std::string* nonce, std::string* type_url);
// Creates an LRS request querying \a server_name.
grpc_slice CreateLrsInitialRequest(const std::string& server_name);
// Creates an LRS request sending client-side load reports. If all the
// counters are zero, returns empty slice.
grpc_slice CreateLrsRequest(std::map<StringView /*cluster_name*/,
std::set<XdsClientStats*>, StringLess>
client_stats_map);
// Parses the LRS response and returns \a
// load_reporting_interval for client-side load reporting. If there is any
// error, the output config is invalid.
grpc_error* ParseLrsResponse(const grpc_slice& encoded_response,
std::set<std::string>* cluster_names,
grpc_millis* load_reporting_interval);
private: private:
DropCategoryList drop_category_list_; const XdsBootstrap::Node* node_;
const char* build_version_;
}; };
struct EdsUpdate {
XdsPriorityListUpdate priority_list_update;
RefCountedPtr<XdsDropConfig> drop_config;
bool drop_all = false;
};
using EdsUpdateMap = std::map<std::string /*eds_service_name*/, EdsUpdate>;
// Creates a request to nack an unsupported resource type.
// Takes ownership of \a error.
grpc_slice XdsUnsupportedTypeNackRequestCreateAndEncode(
const std::string& type_url, const std::string& nonce, grpc_error* error);
// Creates an LDS request querying \a server_name.
// Takes ownership of \a error.
grpc_slice XdsLdsRequestCreateAndEncode(const std::string& server_name,
const XdsBootstrap::Node* node,
const char* build_version,
const std::string& version,
const std::string& nonce,
grpc_error* error);
// Creates an RDS request querying \a route_config_name.
// Takes ownership of \a error.
grpc_slice XdsRdsRequestCreateAndEncode(const std::string& route_config_name,
const XdsBootstrap::Node* node,
const char* build_version,
const std::string& version,
const std::string& nonce,
grpc_error* error);
// Creates a CDS request querying \a cluster_names.
// Takes ownership of \a error.
grpc_slice XdsCdsRequestCreateAndEncode(
const std::set<StringView>& cluster_names, const XdsBootstrap::Node* node,
const char* build_version, const std::string& version,
const std::string& nonce, grpc_error* error);
// Creates an EDS request querying \a eds_service_names.
// Takes ownership of \a error.
grpc_slice XdsEdsRequestCreateAndEncode(
const std::set<StringView>& eds_service_names,
const XdsBootstrap::Node* node, const char* build_version,
const std::string& version, const std::string& nonce, grpc_error* error);
// Parses the ADS response and outputs the validated update for either CDS or
// EDS. If the response can't be parsed at the top level, \a type_url will point
// to an empty string; otherwise, it will point to the received data.
grpc_error* XdsAdsResponseDecodeAndParse(
const grpc_slice& encoded_response, const std::string& expected_server_name,
const std::string& expected_route_config_name,
const std::set<StringView>& expected_eds_service_names,
LdsUpdate* lds_update, RdsUpdate* rds_update, CdsUpdateMap* cds_update_map,
EdsUpdateMap* eds_update_map, std::string* version, std::string* nonce,
std::string* type_url);
// Creates an LRS request querying \a server_name.
grpc_slice XdsLrsRequestCreateAndEncode(const std::string& server_name,
const XdsBootstrap::Node* node,
const char* build_version);
// Creates an LRS request sending client-side load reports. If all the counters
// are zero, returns empty slice.
grpc_slice XdsLrsRequestCreateAndEncode(
std::map<StringView /*cluster_name*/, std::set<XdsClientStats*>, StringLess>
client_stats_map);
// Parses the LRS response and returns \a
// load_reporting_interval for client-side load reporting. If there is any
// error, the output config is invalid.
grpc_error* XdsLrsResponseDecodeAndParse(const grpc_slice& encoded_response,
std::set<std::string>* cluster_names,
grpc_millis* load_reporting_interval);
} // namespace grpc_core } // namespace grpc_core
#endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_XDS_XDS_API_H */ #endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_XDS_XDS_API_H */

@ -187,17 +187,18 @@ class XdsClient::ChannelState::AdsCallState
gpr_log(GPR_INFO, "[xds_client %p] %s", gpr_log(GPR_INFO, "[xds_client %p] %s",
self->ads_calld_->xds_client(), grpc_error_string(error)); self->ads_calld_->xds_client(), grpc_error_string(error));
} }
if (self->type_url_ == kLdsTypeUrl || self->type_url_ == kRdsTypeUrl) { if (self->type_url_ == XdsApi::kLdsTypeUrl ||
self->type_url_ == XdsApi::kRdsTypeUrl) {
self->ads_calld_->xds_client()->service_config_watcher_->OnError( self->ads_calld_->xds_client()->service_config_watcher_->OnError(
error); error);
} else if (self->type_url_ == kCdsTypeUrl) { } else if (self->type_url_ == XdsApi::kCdsTypeUrl) {
ClusterState& state = ClusterState& state =
self->ads_calld_->xds_client()->cluster_map_[self->name_]; self->ads_calld_->xds_client()->cluster_map_[self->name_];
for (const auto& p : state.watchers) { for (const auto& p : state.watchers) {
p.first->OnError(GRPC_ERROR_REF(error)); p.first->OnError(GRPC_ERROR_REF(error));
} }
GRPC_ERROR_UNREF(error); GRPC_ERROR_UNREF(error);
} else if (self->type_url_ == kEdsTypeUrl) { } else if (self->type_url_ == XdsApi::kEdsTypeUrl) {
EndpointState& state = EndpointState& state =
self->ads_calld_->xds_client()->endpoint_map_[self->name_]; self->ads_calld_->xds_client()->endpoint_map_[self->name_];
for (const auto& p : state.watchers) { for (const auto& p : state.watchers) {
@ -237,10 +238,10 @@ class XdsClient::ChannelState::AdsCallState
void SendMessageLocked(const std::string& type_url); void SendMessageLocked(const std::string& type_url);
void AcceptLdsUpdate(LdsUpdate lds_update); void AcceptLdsUpdate(XdsApi::LdsUpdate lds_update);
void AcceptRdsUpdate(RdsUpdate rds_update); void AcceptRdsUpdate(XdsApi::RdsUpdate rds_update);
void AcceptCdsUpdate(CdsUpdateMap cds_update_map); void AcceptCdsUpdate(XdsApi::CdsUpdateMap cds_update_map);
void AcceptEdsUpdate(EdsUpdateMap eds_update_map); void AcceptEdsUpdate(XdsApi::EdsUpdateMap eds_update_map);
static void OnRequestSent(void* arg, grpc_error* error); static void OnRequestSent(void* arg, grpc_error* error);
static void OnRequestSentLocked(void* arg, grpc_error* error); static void OnRequestSentLocked(void* arg, grpc_error* error);
@ -710,13 +711,13 @@ XdsClient::ChannelState::AdsCallState::AdsCallState(
GRPC_CLOSURE_INIT(&on_request_sent_, OnRequestSent, this, GRPC_CLOSURE_INIT(&on_request_sent_, OnRequestSent, this,
grpc_schedule_on_exec_ctx); grpc_schedule_on_exec_ctx);
if (xds_client()->service_config_watcher_ != nullptr) { if (xds_client()->service_config_watcher_ != nullptr) {
Subscribe(kLdsTypeUrl, xds_client()->server_name_); Subscribe(XdsApi::kLdsTypeUrl, xds_client()->server_name_);
} }
for (const auto& p : xds_client()->cluster_map_) { for (const auto& p : xds_client()->cluster_map_) {
Subscribe(kCdsTypeUrl, std::string(p.first)); Subscribe(XdsApi::kCdsTypeUrl, std::string(p.first));
} }
for (const auto& p : xds_client()->endpoint_map_) { for (const auto& p : xds_client()->endpoint_map_) {
Subscribe(kEdsTypeUrl, std::string(p.first)); Subscribe(XdsApi::kEdsTypeUrl, std::string(p.first));
} }
// Op: recv initial metadata. // Op: recv initial metadata.
op = ops; op = ops;
@ -789,35 +790,31 @@ void XdsClient::ChannelState::AdsCallState::SendMessageLocked(
auto& state = state_map_[type_url]; auto& state = state_map_[type_url];
grpc_error* error = state.error; grpc_error* error = state.error;
state.error = GRPC_ERROR_NONE; state.error = GRPC_ERROR_NONE;
const XdsBootstrap::Node* node =
sent_initial_message_ ? nullptr : xds_client()->bootstrap_->node();
const char* build_version =
sent_initial_message_ ? nullptr : xds_client()->build_version_.get();
sent_initial_message_ = true;
grpc_slice request_payload_slice; grpc_slice request_payload_slice;
if (type_url == kLdsTypeUrl) { if (type_url == XdsApi::kLdsTypeUrl) {
request_payload_slice = XdsLdsRequestCreateAndEncode( request_payload_slice = xds_client()->api_.CreateLdsRequest(
xds_client()->server_name_, node, build_version, state.version, xds_client()->server_name_, state.version, state.nonce, error,
state.nonce, error); !sent_initial_message_);
state.subscribed_resources[xds_client()->server_name_]->Start(Ref()); state.subscribed_resources[xds_client()->server_name_]->Start(Ref());
} else if (type_url == kRdsTypeUrl) { } else if (type_url == XdsApi::kRdsTypeUrl) {
request_payload_slice = XdsRdsRequestCreateAndEncode( request_payload_slice = xds_client()->api_.CreateRdsRequest(
xds_client()->route_config_name_, node, build_version, state.version, xds_client()->route_config_name_, state.version, state.nonce, error,
state.nonce, error); !sent_initial_message_);
state.subscribed_resources[xds_client()->route_config_name_]->Start(Ref()); state.subscribed_resources[xds_client()->route_config_name_]->Start(Ref());
} else if (type_url == kCdsTypeUrl) { } else if (type_url == XdsApi::kCdsTypeUrl) {
request_payload_slice = XdsCdsRequestCreateAndEncode( request_payload_slice = xds_client()->api_.CreateCdsRequest(
ClusterNamesForRequest(), node, build_version, state.version, ClusterNamesForRequest(), state.version, state.nonce, error,
state.nonce, error); !sent_initial_message_);
} else if (type_url == kEdsTypeUrl) { } else if (type_url == XdsApi::kEdsTypeUrl) {
request_payload_slice = XdsEdsRequestCreateAndEncode( request_payload_slice = xds_client()->api_.CreateEdsRequest(
EdsServiceNamesForRequest(), node, build_version, state.version, EdsServiceNamesForRequest(), state.version, state.nonce, error,
state.nonce, error); !sent_initial_message_);
} else { } else {
request_payload_slice = XdsUnsupportedTypeNackRequestCreateAndEncode( request_payload_slice = xds_client()->api_.CreateUnsupportedTypeNackRequest(
type_url, state.nonce, state.error); type_url, state.nonce, state.error);
state_map_.erase(type_url); state_map_.erase(type_url);
} }
sent_initial_message_ = true;
// Create message payload. // Create message payload.
send_message_payload_ = send_message_payload_ =
grpc_raw_byte_buffer_create(&request_payload_slice, 1); grpc_raw_byte_buffer_create(&request_payload_slice, 1);
@ -863,7 +860,7 @@ bool XdsClient::ChannelState::AdsCallState::HasSubscribedResources() const {
} }
void XdsClient::ChannelState::AdsCallState::AcceptLdsUpdate( void XdsClient::ChannelState::AdsCallState::AcceptLdsUpdate(
LdsUpdate lds_update) { XdsApi::LdsUpdate lds_update) {
const std::string& cluster_name = const std::string& cluster_name =
lds_update.rds_update.has_value() lds_update.rds_update.has_value()
? lds_update.rds_update.value().cluster_name ? lds_update.rds_update.value().cluster_name
@ -876,7 +873,7 @@ void XdsClient::ChannelState::AdsCallState::AcceptLdsUpdate(
xds_client(), lds_update.route_config_name.c_str(), xds_client(), lds_update.route_config_name.c_str(),
cluster_name.c_str()); cluster_name.c_str());
} }
auto& lds_state = state_map_[kLdsTypeUrl]; auto& lds_state = state_map_[XdsApi::kLdsTypeUrl];
auto& state = lds_state.subscribed_resources[xds_client()->server_name_]; auto& state = lds_state.subscribed_resources[xds_client()->server_name_];
if (state != nullptr) state->Finish(); if (state != nullptr) state->Finish();
// Ignore identical update. // Ignore identical update.
@ -906,19 +903,19 @@ void XdsClient::ChannelState::AdsCallState::AcceptLdsUpdate(
} }
} else { } else {
// Send RDS request for dynamic resolution. // Send RDS request for dynamic resolution.
Subscribe(kRdsTypeUrl, xds_client()->route_config_name_); Subscribe(XdsApi::kRdsTypeUrl, xds_client()->route_config_name_);
} }
} }
void XdsClient::ChannelState::AdsCallState::AcceptRdsUpdate( void XdsClient::ChannelState::AdsCallState::AcceptRdsUpdate(
RdsUpdate rds_update) { XdsApi::RdsUpdate rds_update) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) {
gpr_log(GPR_INFO, gpr_log(GPR_INFO,
"[xds_client %p] RDS update received: " "[xds_client %p] RDS update received: "
"cluster_name=%s", "cluster_name=%s",
xds_client(), rds_update.cluster_name.c_str()); xds_client(), rds_update.cluster_name.c_str());
} }
auto& rds_state = state_map_[kRdsTypeUrl]; auto& rds_state = state_map_[XdsApi::kRdsTypeUrl];
auto& state = auto& state =
rds_state.subscribed_resources[xds_client()->route_config_name_]; rds_state.subscribed_resources[xds_client()->route_config_name_];
if (state != nullptr) state->Finish(); if (state != nullptr) state->Finish();
@ -945,11 +942,11 @@ void XdsClient::ChannelState::AdsCallState::AcceptRdsUpdate(
} }
void XdsClient::ChannelState::AdsCallState::AcceptCdsUpdate( void XdsClient::ChannelState::AdsCallState::AcceptCdsUpdate(
CdsUpdateMap cds_update_map) { XdsApi::CdsUpdateMap cds_update_map) {
auto& cds_state = state_map_[kCdsTypeUrl]; auto& cds_state = state_map_[XdsApi::kCdsTypeUrl];
for (auto& p : cds_update_map) { for (auto& p : cds_update_map) {
const char* cluster_name = p.first.c_str(); const char* cluster_name = p.first.c_str();
CdsUpdate& cds_update = p.second; XdsApi::CdsUpdate& cds_update = p.second;
auto& state = cds_state.subscribed_resources[cluster_name]; auto& state = cds_state.subscribed_resources[cluster_name];
if (state != nullptr) state->Finish(); if (state != nullptr) state->Finish();
if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) {
@ -987,11 +984,11 @@ void XdsClient::ChannelState::AdsCallState::AcceptCdsUpdate(
} }
void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdate( void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdate(
EdsUpdateMap eds_update_map) { XdsApi::EdsUpdateMap eds_update_map) {
auto& eds_state = state_map_[kEdsTypeUrl]; auto& eds_state = state_map_[XdsApi::kEdsTypeUrl];
for (auto& p : eds_update_map) { for (auto& p : eds_update_map) {
const char* eds_service_name = p.first.c_str(); const char* eds_service_name = p.first.c_str();
EdsUpdate& eds_update = p.second; XdsApi::EdsUpdate& eds_update = p.second;
auto& state = eds_state.subscribed_resources[eds_service_name]; auto& state = eds_state.subscribed_resources[eds_service_name];
if (state != nullptr) state->Finish(); if (state != nullptr) state->Finish();
if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) { if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_client_trace)) {
@ -1015,9 +1012,9 @@ void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdate(
const auto& locality = p.second; const auto& locality = p.second;
gpr_log(GPR_INFO, gpr_log(GPR_INFO,
"[xds_client %p] Priority %" PRIuPTR ", locality %" PRIuPTR "[xds_client %p] Priority %" PRIuPTR ", locality %" PRIuPTR
" %s contains %" PRIuPTR " server addresses", " %s has weight %d, contains %" PRIuPTR " server addresses",
xds_client(), priority, locality_count, xds_client(), priority, locality_count,
locality.name->AsHumanReadableString(), locality.name->AsHumanReadableString(), locality.lb_weight,
locality.serverlist.size()); locality.serverlist.size());
for (size_t i = 0; i < locality.serverlist.size(); ++i) { for (size_t i = 0; i < locality.serverlist.size(); ++i) {
char* ipport; char* ipport;
@ -1035,7 +1032,7 @@ void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdate(
} }
for (size_t i = 0; for (size_t i = 0;
i < eds_update.drop_config->drop_category_list().size(); ++i) { i < eds_update.drop_config->drop_category_list().size(); ++i) {
const XdsDropConfig::DropCategory& drop_category = const XdsApi::DropConfig::DropCategory& drop_category =
eds_update.drop_config->drop_category_list()[i]; eds_update.drop_config->drop_category_list()[i];
gpr_log(GPR_INFO, gpr_log(GPR_INFO,
"[xds_client %p] Drop category %s has drop rate %d per million", "[xds_client %p] Drop category %s has drop rate %d per million",
@ -1046,7 +1043,7 @@ void XdsClient::ChannelState::AdsCallState::AcceptEdsUpdate(
EndpointState& endpoint_state = EndpointState& endpoint_state =
xds_client()->endpoint_map_[eds_service_name]; xds_client()->endpoint_map_[eds_service_name];
// Ignore identical update. // Ignore identical update.
const EdsUpdate& prev_update = endpoint_state.update; const XdsApi::EdsUpdate& prev_update = endpoint_state.update;
const bool priority_list_changed = const bool priority_list_changed =
prev_update.priority_list_update != eds_update.priority_list_update; prev_update.priority_list_update != eds_update.priority_list_update;
const bool drop_config_changed = const bool drop_config_changed =
@ -1138,15 +1135,15 @@ void XdsClient::ChannelState::AdsCallState::OnResponseReceivedLocked(
// mode. We will also need to cancel the timer when we receive a serverlist // mode. We will also need to cancel the timer when we receive a serverlist
// from the balancer. // from the balancer.
// Parse the response. // Parse the response.
LdsUpdate lds_update; XdsApi::LdsUpdate lds_update;
RdsUpdate rds_update; XdsApi::RdsUpdate rds_update;
CdsUpdateMap cds_update_map; XdsApi::CdsUpdateMap cds_update_map;
EdsUpdateMap eds_update_map; XdsApi::EdsUpdateMap eds_update_map;
std::string version; std::string version;
std::string nonce; std::string nonce;
std::string type_url; std::string type_url;
// Note that XdsAdsResponseDecodeAndParse() also validate the response. // Note that ParseAdsResponse() also validates the response.
grpc_error* parse_error = XdsAdsResponseDecodeAndParse( grpc_error* parse_error = xds_client->api_.ParseAdsResponse(
response_slice, xds_client->server_name_, xds_client->route_config_name_, response_slice, xds_client->server_name_, xds_client->route_config_name_,
ads_calld->EdsServiceNamesForRequest(), &lds_update, &rds_update, ads_calld->EdsServiceNamesForRequest(), &lds_update, &rds_update,
&cds_update_map, &eds_update_map, &version, &nonce, &type_url); &cds_update_map, &eds_update_map, &version, &nonce, &type_url);
@ -1173,13 +1170,13 @@ void XdsClient::ChannelState::AdsCallState::OnResponseReceivedLocked(
} else { } else {
ads_calld->seen_response_ = true; ads_calld->seen_response_ = true;
// Accept the ADS response according to the type_url. // Accept the ADS response according to the type_url.
if (type_url == kLdsTypeUrl) { if (type_url == XdsApi::kLdsTypeUrl) {
ads_calld->AcceptLdsUpdate(std::move(lds_update)); ads_calld->AcceptLdsUpdate(std::move(lds_update));
} else if (type_url == kRdsTypeUrl) { } else if (type_url == XdsApi::kRdsTypeUrl) {
ads_calld->AcceptRdsUpdate(std::move(rds_update)); ads_calld->AcceptRdsUpdate(std::move(rds_update));
} else if (type_url == kCdsTypeUrl) { } else if (type_url == XdsApi::kCdsTypeUrl) {
ads_calld->AcceptCdsUpdate(std::move(cds_update_map)); ads_calld->AcceptCdsUpdate(std::move(cds_update_map));
} else if (type_url == kEdsTypeUrl) { } else if (type_url == XdsApi::kEdsTypeUrl) {
ads_calld->AcceptEdsUpdate(std::move(eds_update_map)); ads_calld->AcceptEdsUpdate(std::move(eds_update_map));
} }
state.version = std::move(version); state.version = std::move(version);
@ -1258,7 +1255,7 @@ bool XdsClient::ChannelState::AdsCallState::IsCurrentCallOnChannel() const {
std::set<StringView> std::set<StringView>
XdsClient::ChannelState::AdsCallState::ClusterNamesForRequest() { XdsClient::ChannelState::AdsCallState::ClusterNamesForRequest() {
std::set<StringView> cluster_names; std::set<StringView> cluster_names;
for (auto& p : state_map_[kCdsTypeUrl].subscribed_resources) { for (auto& p : state_map_[XdsApi::kCdsTypeUrl].subscribed_resources) {
cluster_names.insert(p.first); cluster_names.insert(p.first);
OrphanablePtr<ResourceState>& state = p.second; OrphanablePtr<ResourceState>& state = p.second;
state->Start(Ref()); state->Start(Ref());
@ -1269,7 +1266,7 @@ XdsClient::ChannelState::AdsCallState::ClusterNamesForRequest() {
std::set<StringView> std::set<StringView>
XdsClient::ChannelState::AdsCallState::EdsServiceNamesForRequest() { XdsClient::ChannelState::AdsCallState::EdsServiceNamesForRequest() {
std::set<StringView> eds_names; std::set<StringView> eds_names;
for (auto& p : state_map_[kEdsTypeUrl].subscribed_resources) { for (auto& p : state_map_[XdsApi::kEdsTypeUrl].subscribed_resources) {
eds_names.insert(p.first); eds_names.insert(p.first);
OrphanablePtr<ResourceState>& state = p.second; OrphanablePtr<ResourceState>& state = p.second;
state->Start(Ref()); state->Start(Ref());
@ -1320,7 +1317,7 @@ void XdsClient::ChannelState::LrsCallState::Reporter::OnNextReportTimerLocked(
void XdsClient::ChannelState::LrsCallState::Reporter::SendReportLocked() { void XdsClient::ChannelState::LrsCallState::Reporter::SendReportLocked() {
// Create a request that contains the load report. // Create a request that contains the load report.
grpc_slice request_payload_slice = grpc_slice request_payload_slice =
XdsLrsRequestCreateAndEncode(xds_client()->ClientStatsMap()); xds_client()->api_.CreateLrsRequest(xds_client()->ClientStatsMap());
// Skip client load report if the counters were all zero in the last // Skip client load report if the counters were all zero in the last
// report and they are still zero in this one. // report and they are still zero in this one.
const bool old_val = last_report_counters_were_zero_; const bool old_val = last_report_counters_were_zero_;
@ -1396,9 +1393,8 @@ XdsClient::ChannelState::LrsCallState::LrsCallState(
nullptr, GRPC_MILLIS_INF_FUTURE, nullptr); nullptr, GRPC_MILLIS_INF_FUTURE, nullptr);
GPR_ASSERT(call_ != nullptr); GPR_ASSERT(call_ != nullptr);
// Init the request payload. // Init the request payload.
grpc_slice request_payload_slice = XdsLrsRequestCreateAndEncode( grpc_slice request_payload_slice =
xds_client()->server_name_, xds_client()->bootstrap_->node(), xds_client()->api_.CreateLrsInitialRequest(xds_client()->server_name_);
xds_client()->build_version_.get());
send_message_payload_ = send_message_payload_ =
grpc_raw_byte_buffer_create(&request_payload_slice, 1); grpc_raw_byte_buffer_create(&request_payload_slice, 1);
grpc_slice_unref_internal(request_payload_slice); grpc_slice_unref_internal(request_payload_slice);
@ -1577,7 +1573,7 @@ void XdsClient::ChannelState::LrsCallState::OnResponseReceivedLocked(
// Parse the response. // Parse the response.
std::set<std::string> new_cluster_names; std::set<std::string> new_cluster_names;
grpc_millis new_load_reporting_interval; grpc_millis new_load_reporting_interval;
grpc_error* parse_error = XdsLrsResponseDecodeAndParse( grpc_error* parse_error = xds_client->api_.ParseLrsResponse(
response_slice, &new_cluster_names, &new_load_reporting_interval); response_slice, &new_cluster_names, &new_load_reporting_interval);
if (parse_error != GRPC_ERROR_NONE) { if (parse_error != GRPC_ERROR_NONE) {
gpr_log(GPR_ERROR, gpr_log(GPR_ERROR,
@ -1722,6 +1718,8 @@ XdsClient::XdsClient(Combiner* combiner, grpc_pollset_set* interested_parties,
combiner_(GRPC_COMBINER_REF(combiner, "xds_client")), combiner_(GRPC_COMBINER_REF(combiner, "xds_client")),
interested_parties_(interested_parties), interested_parties_(interested_parties),
bootstrap_(XdsBootstrap::ReadFromFile(error)), bootstrap_(XdsBootstrap::ReadFromFile(error)),
api_(bootstrap_ == nullptr ? nullptr : bootstrap_->node(),
build_version_.get()),
server_name_(server_name), server_name_(server_name),
service_config_watcher_(std::move(watcher)) { service_config_watcher_(std::move(watcher)) {
if (*error != GRPC_ERROR_NONE) { if (*error != GRPC_ERROR_NONE) {
@ -1744,7 +1742,7 @@ XdsClient::XdsClient(Combiner* combiner, grpc_pollset_set* interested_parties,
chand_ = MakeOrphanable<ChannelState>( chand_ = MakeOrphanable<ChannelState>(
Ref(DEBUG_LOCATION, "XdsClient+ChannelState"), channel); Ref(DEBUG_LOCATION, "XdsClient+ChannelState"), channel);
if (service_config_watcher_ != nullptr) { if (service_config_watcher_ != nullptr) {
chand_->Subscribe(kLdsTypeUrl, std::string(server_name)); chand_->Subscribe(XdsApi::kLdsTypeUrl, std::string(server_name));
} }
} }
@ -1769,7 +1767,7 @@ void XdsClient::WatchClusterData(
if (cluster_state.update.has_value()) { if (cluster_state.update.has_value()) {
w->OnClusterChanged(cluster_state.update.value()); w->OnClusterChanged(cluster_state.update.value());
} }
chand_->Subscribe(kCdsTypeUrl, cluster_name_str); chand_->Subscribe(XdsApi::kCdsTypeUrl, cluster_name_str);
} }
void XdsClient::CancelClusterDataWatch(StringView cluster_name, void XdsClient::CancelClusterDataWatch(StringView cluster_name,
@ -1782,7 +1780,7 @@ void XdsClient::CancelClusterDataWatch(StringView cluster_name,
cluster_state.watchers.erase(it); cluster_state.watchers.erase(it);
if (cluster_state.watchers.empty()) { if (cluster_state.watchers.empty()) {
cluster_map_.erase(cluster_name_str); cluster_map_.erase(cluster_name_str);
chand_->Unsubscribe(kCdsTypeUrl, cluster_name_str); chand_->Unsubscribe(XdsApi::kCdsTypeUrl, cluster_name_str);
} }
} }
} }
@ -1799,7 +1797,7 @@ void XdsClient::WatchEndpointData(
if (!endpoint_state.update.priority_list_update.empty()) { if (!endpoint_state.update.priority_list_update.empty()) {
w->OnEndpointChanged(endpoint_state.update); w->OnEndpointChanged(endpoint_state.update);
} }
chand_->Subscribe(kEdsTypeUrl, eds_service_name_str); chand_->Subscribe(XdsApi::kEdsTypeUrl, eds_service_name_str);
} }
void XdsClient::CancelEndpointDataWatch(StringView eds_service_name, void XdsClient::CancelEndpointDataWatch(StringView eds_service_name,
@ -1812,7 +1810,7 @@ void XdsClient::CancelEndpointDataWatch(StringView eds_service_name,
endpoint_state.watchers.erase(it); endpoint_state.watchers.erase(it);
if (endpoint_state.watchers.empty()) { if (endpoint_state.watchers.empty()) {
endpoint_map_.erase(eds_service_name_str); endpoint_map_.erase(eds_service_name_str);
chand_->Unsubscribe(kEdsTypeUrl, eds_service_name_str); chand_->Unsubscribe(XdsApi::kEdsTypeUrl, eds_service_name_str);
} }
} }
} }

@ -56,7 +56,7 @@ class XdsClient : public InternallyRefCounted<XdsClient> {
public: public:
virtual ~ClusterWatcherInterface() = default; virtual ~ClusterWatcherInterface() = default;
virtual void OnClusterChanged(CdsUpdate cluster_data) = 0; virtual void OnClusterChanged(XdsApi::CdsUpdate cluster_data) = 0;
virtual void OnError(grpc_error* error) = 0; virtual void OnError(grpc_error* error) = 0;
}; };
@ -66,7 +66,7 @@ class XdsClient : public InternallyRefCounted<XdsClient> {
public: public:
virtual ~EndpointWatcherInterface() = default; virtual ~EndpointWatcherInterface() = default;
virtual void OnEndpointChanged(EdsUpdate update) = 0; virtual void OnEndpointChanged(XdsApi::EdsUpdate update) = 0;
virtual void OnError(grpc_error* error) = 0; virtual void OnError(grpc_error* error) = 0;
}; };
@ -175,7 +175,7 @@ class XdsClient : public InternallyRefCounted<XdsClient> {
std::map<ClusterWatcherInterface*, std::unique_ptr<ClusterWatcherInterface>> std::map<ClusterWatcherInterface*, std::unique_ptr<ClusterWatcherInterface>>
watchers; watchers;
// The latest data seen from CDS. // The latest data seen from CDS.
Optional<CdsUpdate> update; Optional<XdsApi::CdsUpdate> update;
}; };
struct EndpointState { struct EndpointState {
@ -184,7 +184,7 @@ class XdsClient : public InternallyRefCounted<XdsClient> {
watchers; watchers;
std::set<XdsClientStats*> client_stats; std::set<XdsClientStats*> client_stats;
// The latest data seen from EDS. // The latest data seen from EDS.
EdsUpdate update; XdsApi::EdsUpdate update;
}; };
// Sends an error notification to all watchers. // Sends an error notification to all watchers.
@ -212,6 +212,7 @@ class XdsClient : public InternallyRefCounted<XdsClient> {
grpc_pollset_set* interested_parties_; grpc_pollset_set* interested_parties_;
std::unique_ptr<XdsBootstrap> bootstrap_; std::unique_ptr<XdsBootstrap> bootstrap_;
XdsApi api_;
const std::string server_name_; const std::string server_name_;

@ -23,6 +23,7 @@
# each change must be ported from one to the other. # each change must be ported from one to the other.
# #
load("@rules_proto//proto:defs.bzl", "proto_library")
load( load(
"//bazel:generate_objc.bzl", "//bazel:generate_objc.bzl",
"generate_objc", "generate_objc",
@ -39,7 +40,7 @@ def proto_library_objc_wrapper(
"""proto_library for adding dependencies to google/protobuf protos """proto_library for adding dependencies to google/protobuf protos
use_well_known_protos - ignored in open source version use_well_known_protos - ignored in open source version
""" """
native.proto_library( proto_library(
name = name, name = name,
srcs = srcs, srcs = srcs,
deps = deps, deps = deps,

@ -14,6 +14,7 @@
licenses(["notice"]) # Apache v2 licenses(["notice"]) # Apache v2
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library")
grpc_package( grpc_package(

@ -14,6 +14,8 @@
licenses(["notice"]) # Apache v2 licenses(["notice"]) # Apache v2
load("@rules_proto//proto:defs.bzl", "proto_library")
proto_library( proto_library(
name = "alts_handshaker_proto", name = "alts_handshaker_proto",
srcs = [ srcs = [

@ -14,6 +14,7 @@
licenses(["notice"]) # Apache v2 licenses(["notice"]) # Apache v2
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library")
grpc_package( grpc_package(

@ -14,6 +14,7 @@
licenses(["notice"]) # Apache v2 licenses(["notice"]) # Apache v2
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library")
grpc_package( grpc_package(

@ -14,6 +14,7 @@
licenses(["notice"]) # Apache v2 licenses(["notice"]) # Apache v2
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library") load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library")
grpc_package( grpc_package(

@ -14,8 +14,9 @@
licenses(["notice"]) # Apache v2 licenses(["notice"]) # Apache v2
load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library")
load("@grpc_python_dependencies//:requirements.bzl", "requirement") load("@grpc_python_dependencies//:requirements.bzl", "requirement")
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel:grpc_build_system.bzl", "grpc_package", "grpc_proto_library")
load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") load("//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library")
grpc_package( grpc_package(

@ -115,6 +115,9 @@ message SimpleResponse {
string server_id = 4; string server_id = 4;
// gRPCLB Path. // gRPCLB Path.
GrpclbRouteType grpclb_route_type = 5; GrpclbRouteType grpclb_route_type = 5;
// Server hostname.
string hostname = 6;
} }
// Client-streaming request. // Client-streaming request.
@ -190,3 +193,17 @@ message ReconnectInfo {
bool passed = 1; bool passed = 1;
repeated int32 backoff_ms = 2; repeated int32 backoff_ms = 2;
} }
message LoadBalancerStatsRequest {
// Request stats for the next num_rpcs sent by client.
int32 num_rpcs = 1;
// If num_rpcs have not completed within timeout_sec, return partial results.
int32 timeout_sec = 2;
}
message LoadBalancerStatsResponse {
// The number of completed RPCs for each peer.
map<string, int32> rpcs_by_peer = 1;
// The number of RPCs that failed to record a remote peer.
int32 num_failures = 2;
}

@ -1,3 +1,4 @@
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@grpc_python_dependencies//:requirements.bzl", "requirement") load("@grpc_python_dependencies//:requirements.bzl", "requirement")
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])

@ -77,3 +77,10 @@ service ReconnectService {
rpc Start(grpc.testing.ReconnectParams) returns (grpc.testing.Empty); rpc Start(grpc.testing.ReconnectParams) returns (grpc.testing.Empty);
rpc Stop(grpc.testing.Empty) returns (grpc.testing.ReconnectInfo); rpc Stop(grpc.testing.Empty) returns (grpc.testing.ReconnectInfo);
} }
// A service used to obtain stats for verifying LB behavior.
service LoadBalancerStatsService {
// Gets the backend distribution for RPCs sent by a test client.
rpc GetClientStats(LoadBalancerStatsRequest)
returns (LoadBalancerStatsResponse) {}
}

@ -367,7 +367,8 @@ cdef class _AioCall(GrpcCallWrapper):
"""Sends one single raw message in bytes.""" """Sends one single raw message in bytes."""
await _send_message(self, await _send_message(self,
message, message,
True, None,
False,
self._loop) self._loop)
async def send_receive_close(self): async def send_receive_close(self):

@ -66,7 +66,7 @@ cdef class CallbackWrapper:
cdef CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler( cdef CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_completion_queue_shutdown', 'grpc_completion_queue_shutdown',
'Unknown', 'Unknown',
RuntimeError) InternalError)
cdef class CallbackCompletionQueue: cdef class CallbackCompletionQueue:
@ -153,12 +153,13 @@ async def _receive_message(GrpcCallWrapper grpc_call_wrapper,
async def _send_message(GrpcCallWrapper grpc_call_wrapper, async def _send_message(GrpcCallWrapper grpc_call_wrapper,
bytes message, bytes message,
bint metadata_sent, Operation send_initial_metadata_op,
int write_flag,
object loop): object loop):
cdef SendMessageOperation op = SendMessageOperation(message, _EMPTY_FLAG) cdef SendMessageOperation op = SendMessageOperation(message, write_flag)
cdef tuple ops = (op,) cdef tuple ops = (op,)
if not metadata_sent: if send_initial_metadata_op is not None:
ops = prepend_send_initial_metadata_op(ops, None) ops = (send_initial_metadata_op,) + ops
await execute_batch(grpc_call_wrapper, ops, loop) await execute_batch(grpc_call_wrapper, ops, loop)
@ -184,7 +185,7 @@ async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper,
grpc_status_code code, grpc_status_code code,
str details, str details,
tuple trailing_metadata, tuple trailing_metadata,
bint metadata_sent, Operation send_initial_metadata_op,
object loop): object loop):
assert code != StatusCode.ok, 'Expecting non-ok status code.' assert code != StatusCode.ok, 'Expecting non-ok status code.'
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation( cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
@ -194,6 +195,6 @@ async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper,
_EMPTY_FLAGS, _EMPTY_FLAGS,
) )
cdef tuple ops = (op,) cdef tuple ops = (op,)
if not metadata_sent: if send_initial_metadata_op is not None:
ops = prepend_send_initial_metadata_op(ops, None) ops = (send_initial_metadata_op,) + ops
await execute_batch(grpc_call_wrapper, ops, loop) await execute_batch(grpc_call_wrapper, ops, loop)

@ -71,8 +71,7 @@ cdef class AioChannel:
other design of API if necessary. other design of API if necessary.
""" """
if self._status in (AIO_CHANNEL_STATUS_DESTROYED, AIO_CHANNEL_STATUS_CLOSING): if self._status in (AIO_CHANNEL_STATUS_DESTROYED, AIO_CHANNEL_STATUS_CLOSING):
# TODO(lidiz) switch to UsageError raise UsageError('Channel is closed.')
raise RuntimeError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline) cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
@ -115,8 +114,7 @@ cdef class AioChannel:
The _AioCall object. The _AioCall object.
""" """
if self.closed(): if self.closed():
# TODO(lidiz) switch to UsageError raise UsageError('Channel is closed.')
raise RuntimeError('Channel is closed.')
cdef CallCredentials cython_call_credentials cdef CallCredentials cython_call_credentials
if python_call_credentials is not None: if python_call_credentials is not None:

@ -67,3 +67,33 @@ class _EOF:
EOF = _EOF() EOF = _EOF()
_COMPRESSION_METADATA_STRING_MAPPING = {
CompressionAlgorithm.none: 'identity',
CompressionAlgorithm.deflate: 'deflate',
CompressionAlgorithm.gzip: 'gzip',
}
class BaseError(Exception):
"""The base class for exceptions generated by gRPC AsyncIO stack."""
class UsageError(BaseError):
"""Raised when the usage of API by applications is inappropriate.
For example, trying to invoke RPC on a closed channel, mixing two styles
of streaming API on the client side. This exception should not be
suppressed.
"""
class AbortError(BaseError):
"""Raised when calling abort in servicer methods.
This exception should not be suppressed. Applications may catch it to
perform certain clean-up logic, and then re-raise it.
"""
class InternalError(BaseError):
"""Raised upon unexpected errors in native code."""

@ -31,10 +31,14 @@ cdef class RPCState(GrpcCallWrapper):
cdef grpc_status_code status_code cdef grpc_status_code status_code
cdef str status_details cdef str status_details
cdef tuple trailing_metadata cdef tuple trailing_metadata
cdef object compression_algorithm
cdef bint disable_next_compression
cdef bytes method(self) cdef bytes method(self)
cdef tuple invocation_metadata(self) cdef tuple invocation_metadata(self)
cdef void raise_for_termination(self) except * cdef void raise_for_termination(self) except *
cdef int get_write_flag(self)
cdef Operation create_send_initial_metadata_op_if_not_sent(self)
cdef enum AioServerStatus: cdef enum AioServerStatus:

@ -21,13 +21,23 @@ cdef int _EMPTY_FLAG = 0
cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.' cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.'
cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.' cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'
cdef _augment_metadata(tuple metadata, object compression):
if compression is None:
return metadata
else:
return ((
GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
_COMPRESSION_METADATA_STRING_MAPPING[compression]
),) + metadata
cdef class _HandlerCallDetails: cdef class _HandlerCallDetails:
def __cinit__(self, str method, tuple invocation_metadata): def __cinit__(self, str method, tuple invocation_metadata):
self.method = method self.method = method
self.invocation_metadata = invocation_metadata self.invocation_metadata = invocation_metadata
class _ServerStoppedError(RuntimeError): class _ServerStoppedError(BaseError):
"""Raised if the server is stopped.""" """Raised if the server is stopped."""
@ -45,6 +55,8 @@ cdef class RPCState:
self.status_code = StatusCode.ok self.status_code = StatusCode.ok
self.status_details = '' self.status_details = ''
self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA
self.compression_algorithm = None
self.disable_next_compression = False
cdef bytes method(self): cdef bytes method(self):
return _slice_bytes(self.details.method) return _slice_bytes(self.details.method)
@ -65,10 +77,28 @@ cdef class RPCState:
if self.abort_exception is not None: if self.abort_exception is not None:
raise self.abort_exception raise self.abort_exception
if self.status_sent: if self.status_sent:
raise RuntimeError(_RPC_FINISHED_DETAILS) raise UsageError(_RPC_FINISHED_DETAILS)
if self.server._status == AIO_SERVER_STATUS_STOPPED: if self.server._status == AIO_SERVER_STATUS_STOPPED:
raise _ServerStoppedError(_SERVER_STOPPED_DETAILS) raise _ServerStoppedError(_SERVER_STOPPED_DETAILS)
cdef int get_write_flag(self):
if self.disable_next_compression:
self.disable_next_compression = False
return WriteFlag.no_compress
else:
return _EMPTY_FLAG
cdef Operation create_send_initial_metadata_op_if_not_sent(self):
cdef SendInitialMetadataOperation op
if self.metadata_sent:
return None
else:
op = SendInitialMetadataOperation(
_augment_metadata(_IMMUTABLE_EMPTY_METADATA, self.compression_algorithm),
_EMPTY_FLAG
)
return op
def __dealloc__(self): def __dealloc__(self):
"""Cleans the Core objects.""" """Cleans the Core objects."""
grpc_call_details_destroy(&self.details) grpc_call_details_destroy(&self.details)
@ -77,11 +107,6 @@ cdef class RPCState:
grpc_call_unref(self.call) grpc_call_unref(self.call)
# TODO(lidiz) inherit this from Python level `AioRpcStatus`, we need to improve
# current code structure to make it happen.
class AbortError(Exception): pass
cdef class _ServicerContext: cdef class _ServicerContext:
cdef RPCState _rpc_state cdef RPCState _rpc_state
cdef object _loop cdef object _loop
@ -116,18 +141,23 @@ cdef class _ServicerContext:
await _send_message(self._rpc_state, await _send_message(self._rpc_state,
serialize(self._response_serializer, message), serialize(self._response_serializer, message),
self._rpc_state.metadata_sent, self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
self._rpc_state.get_write_flag(),
self._loop) self._loop)
if not self._rpc_state.metadata_sent: self._rpc_state.metadata_sent = True
self._rpc_state.metadata_sent = True
async def send_initial_metadata(self, tuple metadata): async def send_initial_metadata(self, tuple metadata):
self._rpc_state.raise_for_termination() self._rpc_state.raise_for_termination()
if self._rpc_state.metadata_sent: if self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent') raise UsageError('Send initial metadata failed: already sent')
else: else:
await _send_initial_metadata(self._rpc_state, metadata, _EMPTY_FLAG, self._loop) await _send_initial_metadata(
self._rpc_state,
_augment_metadata(metadata, self._rpc_state.compression_algorithm),
_EMPTY_FLAG,
self._loop
)
self._rpc_state.metadata_sent = True self._rpc_state.metadata_sent = True
async def abort(self, async def abort(self,
@ -135,7 +165,7 @@ cdef class _ServicerContext:
str details='', str details='',
tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA): tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
if self._rpc_state.abort_exception is not None: if self._rpc_state.abort_exception is not None:
raise RuntimeError('Abort already called!') raise UsageError('Abort already called!')
else: else:
# Keeps track of the exception object. After abort happen, the RPC # Keeps track of the exception object. After abort happen, the RPC
# should stop execution. However, if users decided to suppress it, it # should stop execution. However, if users decided to suppress it, it
@ -156,7 +186,7 @@ cdef class _ServicerContext:
actual_code, actual_code,
details, details,
trailing_metadata, trailing_metadata,
self._rpc_state.metadata_sent, self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
self._loop self._loop
) )
@ -174,6 +204,15 @@ cdef class _ServicerContext:
def set_details(self, str details): def set_details(self, str details):
self._rpc_state.status_details = details self._rpc_state.status_details = details
def set_compression(self, object compression):
if self._rpc_state.metadata_sent:
raise RuntimeError('Compression setting must be specified before sending initial metadata')
else:
self._rpc_state.compression_algorithm = compression
def disable_next_message_compression(self):
self._rpc_state.disable_next_compression = True
cdef _find_method_handler(str method, tuple metadata, list generic_handlers): cdef _find_method_handler(str method, tuple metadata, list generic_handlers):
cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method, cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
@ -217,7 +256,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
# Assembles the batch operations # Assembles the batch operations
cdef tuple finish_ops cdef tuple finish_ops
finish_ops = ( finish_ops = (
SendMessageOperation(response_raw, _EMPTY_FLAGS), SendMessageOperation(response_raw, rpc_state.get_write_flag()),
SendStatusFromServerOperation( SendStatusFromServerOperation(
rpc_state.trailing_metadata, rpc_state.trailing_metadata,
rpc_state.status_code, rpc_state.status_code,
@ -446,7 +485,7 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
status_code, status_code,
'Unexpected %s: %s' % (type(e), e), 'Unexpected %s: %s' % (type(e), e),
rpc_state.trailing_metadata, rpc_state.trailing_metadata,
rpc_state.metadata_sent, rpc_state.create_send_initial_metadata_op_if_not_sent(),
loop loop
) )
@ -492,7 +531,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
StatusCode.unimplemented, StatusCode.unimplemented,
'Method not found!', 'Method not found!',
_IMMUTABLE_EMPTY_METADATA, _IMMUTABLE_EMPTY_METADATA,
rpc_state.metadata_sent, rpc_state.create_send_initial_metadata_op_if_not_sent(),
loop loop
) )
return return
@ -535,13 +574,13 @@ cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandle
cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler( cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
'grpc_server_shutdown_and_notify', 'grpc_server_shutdown_and_notify',
None, None,
RuntimeError) InternalError)
cdef class AioServer: cdef class AioServer:
def __init__(self, loop, thread_pool, generic_handlers, interceptors, def __init__(self, loop, thread_pool, generic_handlers, interceptors,
options, maximum_concurrent_rpcs, compression): options, maximum_concurrent_rpcs):
# NOTE(lidiz) Core objects won't be deallocated automatically. # NOTE(lidiz) Core objects won't be deallocated automatically.
# If AioServer.shutdown is not called, those objects will leak. # If AioServer.shutdown is not called, those objects will leak.
self._server = Server(options) self._server = Server(options)
@ -570,8 +609,6 @@ cdef class AioServer:
raise NotImplementedError() raise NotImplementedError()
if maximum_concurrent_rpcs: if maximum_concurrent_rpcs:
raise NotImplementedError() raise NotImplementedError()
if compression:
raise NotImplementedError()
if thread_pool: if thread_pool:
raise NotImplementedError() raise NotImplementedError()
@ -600,7 +637,7 @@ cdef class AioServer:
wrapper.c_functor() wrapper.c_functor()
) )
if error != GRPC_CALL_OK: if error != GRPC_CALL_OK:
raise RuntimeError("Error in grpc_server_request_call: %s" % error) raise InternalError("Error in grpc_server_request_call: %s" % error)
await future await future
return rpc_state return rpc_state
@ -650,7 +687,7 @@ cdef class AioServer:
if self._status == AIO_SERVER_STATUS_RUNNING: if self._status == AIO_SERVER_STATUS_RUNNING:
return return
elif self._status != AIO_SERVER_STATUS_READY: elif self._status != AIO_SERVER_STATUS_READY:
raise RuntimeError('Server not in ready state') raise UsageError('Server not in ready state')
self._status = AIO_SERVER_STATUS_RUNNING self._status = AIO_SERVER_STATUS_RUNNING
cdef object server_started = self._loop.create_future() cdef object server_started = self._loop.create_future()
@ -746,11 +783,7 @@ cdef class AioServer:
return True return True
def __dealloc__(self): def __dealloc__(self):
"""Deallocation of Core objects are ensured by Python grpc.aio.Server. """Deallocation of Core objects are ensured by Python layer."""
If the Cython representation is deallocated without underlying objects
freed, raise an RuntimeError.
"""
# TODO(lidiz) if users create server, and then dealloc it immediately. # TODO(lidiz) if users create server, and then dealloc it immediately.
# There is a potential memory leak of created Core server. # There is a potential memory leak of created Core server.
if self._status != AIO_SERVER_STATUS_STOPPED: if self._status != AIO_SERVER_STATUS_STOPPED:

@ -118,7 +118,7 @@ cdef class Server:
def cancel_all_calls(self): def cancel_all_calls(self):
if not self.is_shutting_down: if not self.is_shutting_down:
raise RuntimeError("the server must be shutting down to cancel all calls") raise UsageError("the server must be shutting down to cancel all calls")
elif self.is_shutdown: elif self.is_shutdown:
return return
else: else:
@ -136,7 +136,7 @@ cdef class Server:
pass pass
elif not self.is_shutting_down: elif not self.is_shutting_down:
if self.backup_shutdown_queue is None: if self.backup_shutdown_queue is None:
raise RuntimeError('Server shutdown failed: no completion queue.') raise InternalError('Server shutdown failed: no completion queue.')
else: else:
# the user didn't call shutdown - use our backup queue # the user didn't call shutdown - use our backup queue
self._c_shutdown(self.backup_shutdown_queue, None) self._c_shutdown(self.backup_shutdown_queue, None)

@ -17,12 +17,11 @@ gRPC Async API objects may only be used on the thread on which they were
created. AsyncIO doesn't provide thread safety for most of its APIs. created. AsyncIO doesn't provide thread safety for most of its APIs.
""" """
import abc from typing import Any, Optional, Sequence, Tuple
from typing import Any, Optional, Sequence, Text, Tuple
import six
import grpc import grpc
from grpc._cython.cygrpc import EOF, AbortError, init_grpc_aio from grpc._cython.cygrpc import (EOF, AbortError, BaseError, UsageError,
init_grpc_aio)
from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall
from ._call import AioRpcError from ._call import AioRpcError
@ -34,7 +33,7 @@ from ._typing import ChannelArgumentType
def insecure_channel( def insecure_channel(
target: Text, target: str,
options: Optional[ChannelArgumentType] = None, options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None, compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None):
@ -57,7 +56,7 @@ def insecure_channel(
def secure_channel( def secure_channel(
target: Text, target: str,
credentials: grpc.ChannelCredentials, credentials: grpc.ChannelCredentials,
options: Optional[ChannelArgumentType] = None, options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None, compression: Optional[grpc.Compression] = None,
@ -88,4 +87,4 @@ __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall',
'UnaryUnaryMultiCallable', 'ClientCallDetails', 'UnaryUnaryMultiCallable', 'ClientCallDetails',
'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall',
'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel', 'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel',
'AbortError') 'AbortError', 'BaseError', 'UsageError')

@ -19,7 +19,7 @@ RPC, e.g. cancellation.
""" """
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import AsyncIterable, Awaitable, Generic, Optional, Text, Union from typing import AsyncIterable, Awaitable, Generic, Optional, Union
import grpc import grpc
@ -110,7 +110,7 @@ class Call(RpcContext, metaclass=ABCMeta):
""" """
@abstractmethod @abstractmethod
async def details(self) -> Text: async def details(self) -> str:
"""Accesses the details sent by the server. """Accesses the details sent by the server.
Returns: Returns:

@ -16,6 +16,7 @@
import asyncio import asyncio
from functools import partial from functools import partial
import logging import logging
import enum
from typing import AsyncIterable, Awaitable, Dict, Optional from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc import grpc
@ -143,9 +144,13 @@ class AioRpcError(grpc.RpcError):
def _create_rpc_error(initial_metadata: Optional[MetadataType], def _create_rpc_error(initial_metadata: Optional[MetadataType],
status: cygrpc.AioRpcStatus) -> AioRpcError: status: cygrpc.AioRpcStatus) -> AioRpcError:
return AioRpcError(_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], return AioRpcError(
status.details(), initial_metadata, _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
status.trailing_metadata()) status.details(),
initial_metadata,
status.trailing_metadata(),
status.debug_error_string(),
)
class Call: class Call:
@ -234,6 +239,12 @@ class Call:
return self._repr() return self._repr()
class _APIStyle(enum.IntEnum):
UNKNOWN = 0
ASYNC_GENERATOR = 1
READER_WRITER = 2
class _UnaryResponseMixin(Call): class _UnaryResponseMixin(Call):
_call_response: asyncio.Task _call_response: asyncio.Task
@ -279,10 +290,19 @@ class _UnaryResponseMixin(Call):
class _StreamResponseMixin(Call): class _StreamResponseMixin(Call):
_message_aiter: AsyncIterable[ResponseType] _message_aiter: AsyncIterable[ResponseType]
_preparation: asyncio.Task _preparation: asyncio.Task
_response_style: _APIStyle
def _init_stream_response_mixin(self, preparation: asyncio.Task): def _init_stream_response_mixin(self, preparation: asyncio.Task):
self._message_aiter = None self._message_aiter = None
self._preparation = preparation self._preparation = preparation
self._response_style = _APIStyle.UNKNOWN
def _update_response_style(self, style: _APIStyle):
if self._response_style is _APIStyle.UNKNOWN:
self._response_style = style
elif self._response_style is not style:
raise cygrpc.UsageError(
'Please don\'t mix two styles of API for streaming responses')
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
@ -298,6 +318,7 @@ class _StreamResponseMixin(Call):
message = await self._read() message = await self._read()
def __aiter__(self) -> AsyncIterable[ResponseType]: def __aiter__(self) -> AsyncIterable[ResponseType]:
self._update_response_style(_APIStyle.ASYNC_GENERATOR)
if self._message_aiter is None: if self._message_aiter is None:
self._message_aiter = self._fetch_stream_responses() self._message_aiter = self._fetch_stream_responses()
return self._message_aiter return self._message_aiter
@ -324,6 +345,7 @@ class _StreamResponseMixin(Call):
if self.done(): if self.done():
await self._raise_for_status() await self._raise_for_status()
return cygrpc.EOF return cygrpc.EOF
self._update_response_style(_APIStyle.READER_WRITER)
response_message = await self._read() response_message = await self._read()
@ -335,20 +357,28 @@ class _StreamResponseMixin(Call):
class _StreamRequestMixin(Call): class _StreamRequestMixin(Call):
_metadata_sent: asyncio.Event _metadata_sent: asyncio.Event
_done_writing: bool _done_writing_flag: bool
_async_request_poller: Optional[asyncio.Task] _async_request_poller: Optional[asyncio.Task]
_request_style: _APIStyle
def _init_stream_request_mixin( def _init_stream_request_mixin(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]): self, request_async_iterator: Optional[AsyncIterable[RequestType]]):
self._metadata_sent = asyncio.Event(loop=self._loop) self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing = False self._done_writing_flag = False
# If user passes in an async iterator, create a consumer Task. # If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None: if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task( self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_async_iterator)) self._consume_request_iterator(request_async_iterator))
self._request_style = _APIStyle.ASYNC_GENERATOR
else: else:
self._async_request_poller = None self._async_request_poller = None
self._request_style = _APIStyle.READER_WRITER
def _raise_for_different_style(self, style: _APIStyle):
if self._request_style is not style:
raise cygrpc.UsageError(
'Please don\'t mix two styles of API for streaming requests')
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
@ -365,8 +395,8 @@ class _StreamRequestMixin(Call):
self, request_async_iterator: AsyncIterable[RequestType]) -> None: self, request_async_iterator: AsyncIterable[RequestType]) -> None:
try: try:
async for request in request_async_iterator: async for request in request_async_iterator:
await self.write(request) await self._write(request)
await self.done_writing() await self._done_writing()
except AioRpcError as rpc_error: except AioRpcError as rpc_error:
# Rpc status should be exposed through other API. Exceptions raised # Rpc status should be exposed through other API. Exceptions raised
# within this Task won't be retrieved by another coroutine. It's # within this Task won't be retrieved by another coroutine. It's
@ -374,10 +404,10 @@ class _StreamRequestMixin(Call):
_LOGGER.debug('Exception while consuming the request_iterator: %s', _LOGGER.debug('Exception while consuming the request_iterator: %s',
rpc_error) rpc_error)
async def write(self, request: RequestType) -> None: async def _write(self, request: RequestType) -> None:
if self.done(): if self.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing: if self._done_writing_flag:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set(): if not self._metadata_sent.is_set():
await self._metadata_sent.wait() await self._metadata_sent.wait()
@ -394,14 +424,13 @@ class _StreamRequestMixin(Call):
self.cancel() self.cancel()
await self._raise_for_status() await self._raise_for_status()
async def done_writing(self) -> None: async def _done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self.done(): if self.done():
# If the RPC is finished, do nothing. # If the RPC is finished, do nothing.
return return
if not self._done_writing: if not self._done_writing_flag:
# If the done writing is not sent before, try to send it. # If the done writing is not sent before, try to send it.
self._done_writing = True self._done_writing_flag = True
try: try:
await self._cython_call.send_receive_close() await self._cython_call.send_receive_close()
except asyncio.CancelledError: except asyncio.CancelledError:
@ -409,6 +438,18 @@ class _StreamRequestMixin(Call):
self.cancel() self.cancel()
await self._raise_for_status() await self._raise_for_status()
async def write(self, request: RequestType) -> None:
self._raise_for_different_style(_APIStyle.READER_WRITER)
await self._write(request)
async def done_writing(self) -> None:
"""Signal peer that client is done writing.
This method is idempotent.
"""
self._raise_for_different_style(_APIStyle.READER_WRITER)
await self._done_writing()
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls. """Object for managing unary-unary RPC calls.

@ -13,13 +13,15 @@
# limitations under the License. # limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python.""" """Invocation-side implementation of gRPC Asyncio Python."""
import asyncio import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet, Text from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet
from weakref import WeakSet from weakref import WeakSet
import logging import logging
import grpc import grpc
from grpc import _common from grpc import _common
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc import _compression
from grpc import _grpcio_metadata
from . import _base_call from . import _base_call
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
@ -31,6 +33,20 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
from ._utils import _timeout_to_deadline from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple() _IMMUTABLE_EMPTY_TUPLE = tuple()
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
compression_channel_argument = _compression.create_channel_option(
compression)
user_agent_channel_argument = ((
cygrpc.ChannelArgKey.primary_user_agent_string,
_USER_AGENT,
),)
return tuple(base_options
) + compression_channel_argument + user_agent_channel_argument
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -110,7 +126,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
request: Any, request: Any,
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None compression: Optional[grpc.Compression] = None
@ -139,10 +155,7 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
metadata, status code, and details. metadata, status code, and details.
""" """
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") metadata = _compression.augment_metadata(metadata, compression)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
if not self._interceptors: if not self._interceptors:
call = UnaryUnaryCall(request, _timeout_to_deadline(timeout), call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
@ -168,7 +181,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
request: Any, request: Any,
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None compression: Optional[grpc.Compression] = None
@ -192,11 +205,9 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
A Call object instance which is an awaitable object. A Call object instance which is an awaitable object.
""" """
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") metadata = _compression.augment_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = UnaryStreamCall(request, deadline, metadata, credentials, call = UnaryStreamCall(request, deadline, metadata, credentials,
wait_for_ready, self._channel, self._method, wait_for_ready, self._channel, self._method,
@ -212,7 +223,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
def __call__(self, def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None, request_async_iterator: Optional[AsyncIterable[Any]] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None compression: Optional[grpc.Compression] = None
@ -241,11 +252,9 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
metadata, status code, and details. metadata, status code, and details.
""" """
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") metadata = _compression.augment_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = StreamUnaryCall(request_async_iterator, deadline, metadata, call = StreamUnaryCall(request_async_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, credentials, wait_for_ready, self._channel,
@ -261,7 +270,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
def __call__(self, def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None, request_async_iterator: Optional[AsyncIterable[Any]] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None compression: Optional[grpc.Compression] = None
@ -290,11 +299,9 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
metadata, status code, and details. metadata, status code, and details.
""" """
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") metadata = _compression.augment_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = StreamStreamCall(request_async_iterator, deadline, metadata, call = StreamStreamCall(request_async_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, credentials, wait_for_ready, self._channel,
@ -314,7 +321,7 @@ class Channel:
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_ongoing_calls: _OngoingCalls _ongoing_calls: _OngoingCalls
def __init__(self, target: Text, options: Optional[ChannelArgumentType], def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials], credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression], compression: Optional[grpc.Compression],
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]): interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]):
@ -329,10 +336,6 @@ class Channel:
interceptors: An optional list of interceptors that would be used for interceptors: An optional list of interceptors that would be used for
intercepting any RPC executed with that channel. intercepting any RPC executed with that channel.
""" """
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
if interceptors is None: if interceptors is None:
self._unary_unary_interceptors = None self._unary_unary_interceptors = None
else: else:
@ -352,8 +355,10 @@ class Channel:
.format(invalid_interceptors)) .format(invalid_interceptors))
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._channel = cygrpc.AioChannel(_common.encode(target), options, self._channel = cygrpc.AioChannel(
credentials, self._loop) _common.encode(target),
_augment_channel_arguments(options, compression), credentials,
self._loop)
self._ongoing_calls = _OngoingCalls() self._ongoing_calls = _OngoingCalls()
async def __aenter__(self): async def __aenter__(self):
@ -456,9 +461,16 @@ class Channel:
assert await self._channel.watch_connectivity_state( assert await self._channel.watch_connectivity_state(
last_observed_state.value[0], None) last_observed_state.value[0], None)
async def channel_ready(self) -> None:
"""Creates a coroutine that ends when a Channel is ready."""
state = self.get_state(try_to_connect=True)
while state != grpc.ChannelConnectivity.READY:
await self.wait_for_state_change(state)
state = self.get_state(try_to_connect=True)
def unary_unary( def unary_unary(
self, self,
method: Text, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryUnaryMultiCallable: ) -> UnaryUnaryMultiCallable:
@ -484,7 +496,7 @@ class Channel:
def unary_stream( def unary_stream(
self, self,
method: Text, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable: ) -> UnaryStreamMultiCallable:
@ -495,7 +507,7 @@ class Channel:
def stream_unary( def stream_unary(
self, self,
method: Text, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable: ) -> StreamUnaryMultiCallable:
@ -506,7 +518,7 @@ class Channel:
def stream_stream( def stream_stream(
self, self,
method: Text, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable: ) -> StreamStreamMultiCallable:

@ -16,7 +16,7 @@ import asyncio
import collections import collections
import functools import functools
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Callable, Optional, Iterator, Sequence, Text, Union from typing import Callable, Optional, Iterator, Sequence, Union
import grpc import grpc
from grpc._cython import cygrpc from grpc._cython import cygrpc
@ -36,7 +36,7 @@ class ClientCallDetails(
('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')), ('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')),
grpc.ClientCallDetails): grpc.ClientCallDetails):
method: Text method: str
timeout: Optional[float] timeout: Optional[float]
metadata: Optional[MetadataType] metadata: Optional[MetadataType]
credentials: Optional[grpc.CallCredentials] credentials: Optional[grpc.CallCredentials]

@ -13,39 +13,52 @@
# limitations under the License. # limitations under the License.
"""Server-side implementation of gRPC Asyncio Python.""" """Server-side implementation of gRPC Asyncio Python."""
from typing import Text, Optional
import asyncio import asyncio
from concurrent.futures import Executor
from typing import Any, Optional, Sequence
import grpc import grpc
from grpc import _common from grpc import _common, _compression
from grpc._cython import cygrpc from grpc._cython import cygrpc
from ._typing import ChannelArgumentType
def _augment_channel_arguments(base_options: ChannelArgumentType,
compression: Optional[grpc.Compression]):
compression_option = _compression.create_channel_option(compression)
return tuple(base_options) + compression_option
class Server: class Server:
"""Serves RPCs.""" """Serves RPCs."""
def __init__(self, thread_pool, generic_handlers, interceptors, options, def __init__(self, thread_pool: Optional[Executor],
maximum_concurrent_rpcs, compression): generic_handlers: Optional[Sequence[grpc.GenericRpcHandler]],
interceptors: Optional[Sequence[Any]],
options: ChannelArgumentType,
maximum_concurrent_rpcs: Optional[int],
compression: Optional[grpc.Compression]):
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._server = cygrpc.AioServer(self._loop, thread_pool, self._server = cygrpc.AioServer(
generic_handlers, interceptors, options, self._loop, thread_pool, generic_handlers, interceptors,
maximum_concurrent_rpcs, compression) _augment_channel_arguments(options, compression),
maximum_concurrent_rpcs)
def add_generic_rpc_handlers( def add_generic_rpc_handlers(
self, self,
generic_rpc_handlers, generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None:
# generic_rpc_handlers: Iterable[grpc.GenericRpcHandlers]
) -> None:
"""Registers GenericRpcHandlers with this Server. """Registers GenericRpcHandlers with this Server.
This method is only safe to call before the server is started. This method is only safe to call before the server is started.
Args: Args:
generic_rpc_handlers: An iterable of GenericRpcHandlers that will be generic_rpc_handlers: A sequence of GenericRpcHandlers that will be
used to service RPCs. used to service RPCs.
""" """
self._server.add_generic_rpc_handlers(generic_rpc_handlers) self._server.add_generic_rpc_handlers(generic_rpc_handlers)
def add_insecure_port(self, address: Text) -> int: def add_insecure_port(self, address: str) -> int:
"""Opens an insecure port for accepting RPCs. """Opens an insecure port for accepting RPCs.
This method may only be called before starting the server. This method may only be called before starting the server.
@ -59,7 +72,7 @@ class Server:
""" """
return self._server.add_insecure_port(_common.encode(address)) return self._server.add_insecure_port(_common.encode(address))
def add_secure_port(self, address: Text, def add_secure_port(self, address: str,
server_credentials: grpc.ServerCredentials) -> int: server_credentials: grpc.ServerCredentials) -> int:
"""Opens a secure port for accepting RPCs. """Opens a secure port for accepting RPCs.
@ -141,12 +154,12 @@ class Server:
self._loop.create_task(self._server.shutdown(None)) self._loop.create_task(self._server.shutdown(None))
def server(migration_thread_pool=None, def server(migration_thread_pool: Optional[Executor] = None,
handlers=None, handlers: Optional[Sequence[grpc.GenericRpcHandler]] = None,
interceptors=None, interceptors: Optional[Sequence[Any]] = None,
options=None, options: Optional[ChannelArgumentType] = None,
maximum_concurrent_rpcs=None, maximum_concurrent_rpcs: Optional[int] = None,
compression=None): compression: Optional[grpc.Compression] = None):
"""Creates a Server with which RPCs can be serviced. """Creates a Server with which RPCs can be serviced.
Args: Args:
@ -166,7 +179,8 @@ def server(migration_thread_pool=None,
indicate no limit. indicate no limit.
compression: An element of grpc.compression, e.g. compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This compression algorithm will be used for the grpc.compression.Gzip. This compression algorithm will be used for the
lifetime of the server unless overridden. This is an EXPERIMENTAL option. lifetime of the server unless overridden by set_compression. This is an
EXPERIMENTAL option.
Returns: Returns:
A Server object. A Server object.

@ -13,15 +13,15 @@
# limitations under the License. # limitations under the License.
"""Common types for gRPC Async API""" """Common types for gRPC Async API"""
from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar from typing import Any, AnyStr, Callable, Sequence, Tuple, TypeVar
from grpc._cython.cygrpc import EOF from grpc._cython.cygrpc import EOF
RequestType = TypeVar('RequestType') RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType') ResponseType = TypeVar('ResponseType')
SerializingFunction = Callable[[Any], bytes] SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any] DeserializingFunction = Callable[[bytes], Any]
MetadatumType = Tuple[Text, AnyStr] MetadatumType = Tuple[str, AnyStr]
MetadataType = Sequence[MetadatumType] MetadataType = Sequence[MetadatumType]
ChannelArgumentType = Sequence[Tuple[Text, Any]] ChannelArgumentType = Sequence[Tuple[str, Any]]
EOFType = type(EOF) EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None] DoneCallbackType = Callable[[Any], None]

@ -1,3 +1,4 @@
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library")
proto_library( proto_library(

@ -9,9 +9,11 @@
"unit.call_test.TestUnaryStreamCall", "unit.call_test.TestUnaryStreamCall",
"unit.call_test.TestUnaryUnaryCall", "unit.call_test.TestUnaryUnaryCall",
"unit.channel_argument_test.TestChannelArgument", "unit.channel_argument_test.TestChannelArgument",
"unit.channel_ready_test.TestChannelReady",
"unit.channel_test.TestChannel", "unit.channel_test.TestChannel",
"unit.close_channel_test.TestCloseChannel", "unit.close_channel_test.TestCloseChannel",
"unit.close_channel_test.TestOngoingCalls", "unit.close_channel_test.TestOngoingCalls",
"unit.compression_test.TestCompression",
"unit.connectivity_test.TestConnectivityState", "unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback", "unit.done_callback_test.TestDoneCallback",
"unit.init_test.TestInsecureChannel", "unit.init_test.TestInsecureChannel",

@ -0,0 +1,67 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the channel_ready function."""
import asyncio
import gc
import logging
import time
import unittest
import grpc
from grpc.experimental import aio
from tests.unit.framework.common import get_socket, test_constants
from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
class TestChannelReady(AioTestBase):
async def setUp(self):
address, self._port, self._socket = get_socket(listen=False)
self._channel = aio.insecure_channel(f"{address}:{self._port}")
self._socket.close()
async def tearDown(self):
await self._channel.close()
async def test_channel_ready_success(self):
# Start `channel_ready` as another Task
channel_ready_task = self.loop.create_task(
self._channel.channel_ready())
# Wait for TRANSIENT_FAILURE
await _common.block_until_certain_state(
self._channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE)
try:
# Start the server
_, server = await start_test_server(port=self._port)
# The RPC should recover itself
await channel_ready_task
finally:
await server.stop(None)
async def test_channel_ready_blocked(self):
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(self._channel.channel_ready(),
test_constants.SHORT_TIMEOUT)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -0,0 +1,196 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior around the compression mechanism."""
import asyncio
import logging
import platform
import random
import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit import _common
_GZIP_CHANNEL_ARGUMENT = ('grpc.default_compression_algorithm', 2)
_GZIP_DISABLED_CHANNEL_ARGUMENT = ('grpc.compression_enabled_algorithms_bitset',
3)
_DEFLATE_DISABLED_CHANNEL_ARGUMENT = (
'grpc.compression_enabled_algorithms_bitset', 5)
_TEST_UNARY_UNARY = '/test/TestUnaryUnary'
_TEST_SET_COMPRESSION = '/test/TestSetCompression'
_TEST_DISABLE_COMPRESSION_UNARY = '/test/TestDisableCompressionUnary'
_TEST_DISABLE_COMPRESSION_STREAM = '/test/TestDisableCompressionStream'
_REQUEST = b'\x01' * 100
_RESPONSE = b'\x02' * 100
async def _test_unary_unary(unused_request, unused_context):
return _RESPONSE
async def _test_set_compression(unused_request_iterator, context):
assert _REQUEST == await context.read()
context.set_compression(grpc.Compression.Deflate)
await context.write(_RESPONSE)
try:
context.set_compression(grpc.Compression.Deflate)
except RuntimeError:
# NOTE(lidiz) Testing if the servicer context raises exception when
# the set_compression method is called after initial_metadata sent.
# After the initial_metadata sent, the server-side has no control over
# which compression algorithm it should use.
pass
else:
raise ValueError(
'Expecting exceptions if set_compression is not effective')
async def _test_disable_compression_unary(request, context):
assert _REQUEST == request
context.set_compression(grpc.Compression.Deflate)
context.disable_next_message_compression()
return _RESPONSE
async def _test_disable_compression_stream(unused_request_iterator, context):
assert _REQUEST == await context.read()
context.set_compression(grpc.Compression.Deflate)
await context.write(_RESPONSE)
context.disable_next_message_compression()
await context.write(_RESPONSE)
await context.write(_RESPONSE)
_ROUTING_TABLE = {
_TEST_UNARY_UNARY:
grpc.unary_unary_rpc_method_handler(_test_unary_unary),
_TEST_SET_COMPRESSION:
grpc.stream_stream_rpc_method_handler(_test_set_compression),
_TEST_DISABLE_COMPRESSION_UNARY:
grpc.unary_unary_rpc_method_handler(_test_disable_compression_unary),
_TEST_DISABLE_COMPRESSION_STREAM:
grpc.stream_stream_rpc_method_handler(_test_disable_compression_stream),
}
class _GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_call_details):
return _ROUTING_TABLE.get(handler_call_details.method)
async def _start_test_server(options=None):
server = aio.server(options=options)
port = server.add_insecure_port('[::]:0')
server.add_generic_rpc_handlers((_GenericHandler(),))
await server.start()
return f'localhost:{port}', server
class TestCompression(AioTestBase):
async def setUp(self):
server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,)
self._address, self._server = await _start_test_server(server_options)
self._channel = aio.insecure_channel(self._address)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_channel_level_compression_baned_compression(self):
# GZIP is disabled, this call should fail
async with aio.insecure_channel(
self._address, compression=grpc.Compression.Gzip) as channel:
multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
call = multicallable(_REQUEST)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
async def test_channel_level_compression_allowed_compression(self):
# Deflate is allowed, this call should succeed
async with aio.insecure_channel(
self._address, compression=grpc.Compression.Deflate) as channel:
multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
call = multicallable(_REQUEST)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_client_call_level_compression_baned_compression(self):
multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)
# GZIP is disabled, this call should fail
call = multicallable(_REQUEST, compression=grpc.Compression.Gzip)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
async def test_client_call_level_compression_allowed_compression(self):
multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)
# Deflate is allowed, this call should succeed
call = multicallable(_REQUEST, compression=grpc.Compression.Deflate)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_call_level_compression(self):
multicallable = self._channel.stream_stream(_TEST_SET_COMPRESSION)
call = multicallable()
await call.write(_REQUEST)
await call.done_writing()
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_disable_compression_unary(self):
multicallable = self._channel.unary_unary(
_TEST_DISABLE_COMPRESSION_UNARY)
call = multicallable(_REQUEST)
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_disable_compression_stream(self):
multicallable = self._channel.stream_stream(
_TEST_DISABLE_COMPRESSION_STREAM)
call = multicallable()
await call.write(_REQUEST)
await call.done_writing()
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(_RESPONSE, await call.read())
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_server_default_compression_algorithm(self):
server = aio.server(compression=grpc.Compression.Deflate)
port = server.add_insecure_port('[::]:0')
server.add_generic_rpc_handlers((_GenericHandler(),))
await server.start()
async with aio.insecure_channel(f'localhost:{port}') as channel:
multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
call = multicallable(_REQUEST)
self.assertEqual(_RESPONSE, await call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
await server.stop(None)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -102,7 +102,7 @@ class TestConnectivityState(AioTestBase):
# It can raise exceptions since it is an usage error, but it should not # It can raise exceptions since it is an usage error, but it should not
# segfault or abort. # segfault or abort.
with self.assertRaises(RuntimeError): with self.assertRaises(aio.UsageError):
await channel.wait_for_state_change( await channel.wait_for_state_change(
grpc.ChannelConnectivity.SHUTDOWN) grpc.ChannelConnectivity.SHUTDOWN)

@ -231,14 +231,10 @@ class TestServer(AioTestBase):
# Uses reader API # Uses reader API
self.assertEqual(_RESPONSE, await call.read()) self.assertEqual(_RESPONSE, await call.read())
# Uses async generator API # Uses async generator API, mixed!
response_cnt = 0 with self.assertRaises(aio.UsageError):
async for response in call: async for response in call:
response_cnt += 1 self.assertEqual(_RESPONSE, response)
self.assertEqual(_RESPONSE, response)
self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_stream_unary_async_generator(self): async def test_stream_unary_async_generator(self):
stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN) stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)

@ -991,7 +991,8 @@ class XdsEnd2endTest : public ::testing::TestWithParam<TestType> {
} }
std::tuple<int, int, int> WaitForAllBackends(size_t start_index = 0, std::tuple<int, int, int> WaitForAllBackends(size_t start_index = 0,
size_t stop_index = 0) { size_t stop_index = 0,
bool reset_counters = true) {
int num_ok = 0; int num_ok = 0;
int num_failure = 0; int num_failure = 0;
int num_drops = 0; int num_drops = 0;
@ -999,7 +1000,7 @@ class XdsEnd2endTest : public ::testing::TestWithParam<TestType> {
while (!SeenAllBackends(start_index, stop_index)) { while (!SeenAllBackends(start_index, stop_index)) {
SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops); SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops);
} }
ResetBackendCounters(); if (reset_counters) ResetBackendCounters();
gpr_log(GPR_INFO, gpr_log(GPR_INFO,
"Performed %d warm up requests against the backends. " "Performed %d warm up requests against the backends. "
"%d succeeded, %d failed, %d dropped.", "%d succeeded, %d failed, %d dropped.",
@ -2202,6 +2203,41 @@ TEST_P(FailoverTest, UpdatePriority) {
EXPECT_EQ(2U, balancers_[0]->ads_service()->response_count()); EXPECT_EQ(2U, balancers_[0]->ads_service()->response_count());
} }
// Moves all localities in the current priority to a higher priority.
TEST_P(FailoverTest, MoveAllLocalitiesInCurrentPriorityToHigherPriority) {
SetNextResolution({});
SetNextResolutionForLbChannelAllBalancers();
// First update:
// - Priority 0 is locality 0, containing backend 0, which is down.
// - Priority 1 is locality 1, containing backends 1 and 2, which are up.
ShutdownBackend(0);
AdsServiceImpl::ResponseArgs args({
{"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0},
{"locality1", GetBackendPorts(1, 3), kDefaultLocalityWeight, 1},
});
ScheduleResponseForBalancer(0, AdsServiceImpl::BuildResponse(args), 0);
// Second update:
// - Priority 0 contains both localities 0 and 1.
// - Priority 1 is not present.
// - We add backend 3 to locality 1, just so we have a way to know
// when the update has been seen by the client.
args = AdsServiceImpl::ResponseArgs({
{"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0},
{"locality1", GetBackendPorts(1, 4), kDefaultLocalityWeight, 0},
});
ScheduleResponseForBalancer(0, AdsServiceImpl::BuildResponse(args), 1000);
// When we get the first update, all backends in priority 0 are down,
// so we will create priority 1. Backends 1 and 2 should have traffic,
// but backend 3 should not.
WaitForAllBackends(1, 3, false);
EXPECT_EQ(0UL, backends_[3]->backend_service()->request_count());
// When backend 3 gets traffic, we know the second update has been seen.
WaitForBackend(3);
// The ADS service got a single request, and sent a single response.
EXPECT_EQ(1U, balancers_[0]->ads_service()->request_count());
EXPECT_EQ(2U, balancers_[0]->ads_service()->response_count());
}
using DropTest = BasicTest; using DropTest = BasicTest;
// Tests that RPCs are dropped according to the drop config. // Tests that RPCs are dropped according to the drop config.

@ -14,9 +14,9 @@ _PYTHON3_BIN_PATH = "PYTHON3_BIN_PATH"
_PYTHON3_LIB_PATH = "PYTHON3_LIB_PATH" _PYTHON3_LIB_PATH = "PYTHON3_LIB_PATH"
_HEADERS_HELP = ( _HEADERS_HELP = (
"Are Python headers installed? Try installing python-dev or " + "Are Python headers installed? Try installing python-dev or " +
"python3-dev on Debian-based systems. Try python-devel or python3-devel " + "python3-dev on Debian-based systems. Try python-devel or python3-devel " +
"on Redhat-based systems." "on Redhat-based systems."
) )
def _tpl(repository_ctx, tpl, substitutions = {}, out = None): def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
@ -246,11 +246,11 @@ def _get_python_include(repository_ctx, python_bin):
_execute( _execute(
repository_ctx, repository_ctx,
[ [
python_bin, python_bin,
"-c", "-c",
"import os;" + "import os;" +
"main_header = os.path.join('{}', 'Python.h');".format(include_path) + "main_header = os.path.join('{}', 'Python.h');".format(include_path) +
"assert os.path.exists(main_header), main_header + ' does not exist.'" "assert os.path.exists(main_header), main_header + ' does not exist.'",
], ],
error_msg = "Unable to find Python headers for {}".format(python_bin), error_msg = "Unable to find Python headers for {}".format(python_bin),
error_details = _HEADERS_HELP, error_details = _HEADERS_HELP,

@ -1,3 +1,4 @@
load("@rules_proto//proto:defs.bzl", "proto_library")
load( load(
"//bazel:build_defs.bzl", "//bazel:build_defs.bzl",
"generated_file_staleness_test", "generated_file_staleness_test",
@ -56,13 +57,13 @@ config_setting(
cc_library( cc_library(
name = "port", name = "port",
srcs = [
"upb/port.c",
],
textual_hdrs = [ textual_hdrs = [
"upb/port_def.inc", "upb/port_def.inc",
"upb/port_undef.inc", "upb/port_undef.inc",
], ],
srcs = [
"upb/port.c",
],
) )
cc_library( cc_library(
@ -159,8 +160,8 @@ cc_library(
cc_library( cc_library(
name = "legacy_msg_reflection", name = "legacy_msg_reflection",
srcs = [ srcs = [
"upb/msg.h",
"upb/legacy_msg_reflection.c", "upb/legacy_msg_reflection.c",
"upb/msg.h",
], ],
hdrs = ["upb/legacy_msg_reflection.h"], hdrs = ["upb/legacy_msg_reflection.h"],
copts = select({ copts = select({
@ -190,8 +191,8 @@ cc_library(
"//conditions:default": COPTS, "//conditions:default": COPTS,
}), }),
deps = [ deps = [
":reflection",
":port", ":port",
":reflection",
":table", ":table",
":upb", ":upb",
], ],
@ -220,8 +221,8 @@ cc_library(
deps = [ deps = [
":descriptor_upbproto", ":descriptor_upbproto",
":handlers", ":handlers",
":reflection",
":port", ":port",
":reflection",
":table", ":table",
":upb", ":upb",
], ],

@ -8,6 +8,7 @@ load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
# copybara:strip_for_google3_begin # copybara:strip_for_google3_begin
load("@bazel_skylib//lib:versions.bzl", "versions") load("@bazel_skylib//lib:versions.bzl", "versions")
load("@rules_proto//proto:defs.bzl", "ProtoInfo")
load("@upb_bazel_version//:bazel_version.bzl", "bazel_version") load("@upb_bazel_version//:bazel_version.bzl", "bazel_version")
# copybara:strip_end # copybara:strip_end
@ -22,6 +23,7 @@ def _get_real_short_path(file):
if short_path.startswith("../"): if short_path.startswith("../"):
second_slash = short_path.index("/", 3) second_slash = short_path.index("/", 3)
short_path = short_path[second_slash + 1:] short_path = short_path[second_slash + 1:]
# Sometimes it has another few prefixes like: # Sometimes it has another few prefixes like:
# _virtual_imports/any_proto/google/protobuf/any.proto # _virtual_imports/any_proto/google/protobuf/any.proto
# We want just google/protobuf/any.proto. # We want just google/protobuf/any.proto.

@ -1,3 +1,4 @@
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@upb//bazel:upb_proto_library.bzl", "upb_proto_library") load("@upb//bazel:upb_proto_library.bzl", "upb_proto_library")
proto_library( proto_library(

@ -152,6 +152,7 @@ LANG_RELEASE_MATRIX = {
('v1.24.0', ReleaseInfo(runtimes=['go1.11'])), ('v1.24.0', ReleaseInfo(runtimes=['go1.11'])),
('v1.25.0', ReleaseInfo(runtimes=['go1.11'])), ('v1.25.0', ReleaseInfo(runtimes=['go1.11'])),
('v1.26.0', ReleaseInfo(runtimes=['go1.11'])), ('v1.26.0', ReleaseInfo(runtimes=['go1.11'])),
('v1.27.1', ReleaseInfo(runtimes=['go1.11'])),
]), ]),
'java': 'java':
OrderedDict([ OrderedDict([

@ -0,0 +1,36 @@
#!/bin/bash
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -ex
# change to grpc repo root
cd "$(dirname "$0")/../../.."
sudo apt-get install -y python3-pip
sudo python3 -m pip install grpcio grpcio-tools google-api-python-client google-auth-httplib2
# Prepare generated Python code.
TOOLS_DIR=tools/run_tests
PROTO_SOURCE_DIR=src/proto/grpc/testing
PROTO_DEST_DIR=${TOOLS_DIR}/${PROTO_SOURCE_DIR}
mkdir -p ${PROTO_DEST_DIR}
python3 -m grpc_tools.protoc \
--proto_path=. \
--python_out=${TOOLS_DIR} \
--grpc_python_out=${TOOLS_DIR} \
${PROTO_SOURCE_DIR}/test.proto \
${PROTO_SOURCE_DIR}/messages.proto \
${PROTO_SOURCE_DIR}/empty.proto

@ -0,0 +1,595 @@
#!/usr/bin/env python
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run xDS integration tests on GCP using Traffic Director."""
import argparse
import googleapiclient.discovery
import grpc
import logging
import os
import shlex
import socket
import subprocess
import sys
import tempfile
import time
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
logger = logging.getLogger(__name__)
console_handler = logging.StreamHandler()
logger.addHandler(console_handler)
argp = argparse.ArgumentParser(description='Run xDS interop tests on GCP')
argp.add_argument('--project_id', help='GCP project id')
argp.add_argument(
'--gcp_suffix',
default='',
help='Optional suffix for all generated GCP resource names. Useful to ensure '
'distinct names across test runs.')
argp.add_argument('--test_case',
default=None,
choices=['all', 'ping_pong', 'round_robin'])
argp.add_argument(
'--client_cmd',
default=None,
help='Command to launch xDS test client. This script will fill in '
'{service_host}, {service_port},{stats_port} and {qps} parameters using '
'str.format(), and generate the GRPC_XDS_BOOTSTRAP file.')
argp.add_argument('--zone', default='us-central1-a')
argp.add_argument('--qps', default=10, help='Client QPS')
argp.add_argument(
'--wait_for_backend_sec',
default=900,
help='Time limit for waiting for created backend services to report healthy '
'when launching test suite')
argp.add_argument(
'--keep_gcp_resources',
default=False,
action='store_true',
help=
'Leave GCP VMs and configuration running after test. Default behavior is '
'to delete when tests complete.')
argp.add_argument(
'--tolerate_gcp_errors',
default=False,
action='store_true',
help=
'Continue with test even when an error occurs during setup. Intended for '
'manual testing, where attempts to recreate any GCP resources already '
'existing will result in an error')
argp.add_argument('--verbose',
help='verbose log output',
default=False,
action="store_true")
args = argp.parse_args()
if args.verbose:
logger.setLevel(logging.DEBUG)
PROJECT_ID = args.project_id
ZONE = args.zone
QPS = args.qps
TEST_CASE = args.test_case
CLIENT_CMD = args.client_cmd
WAIT_FOR_BACKEND_SEC = args.wait_for_backend_sec
TEMPLATE_NAME = 'test-template' + args.gcp_suffix
INSTANCE_GROUP_NAME = 'test-ig' + args.gcp_suffix
HEALTH_CHECK_NAME = 'test-hc' + args.gcp_suffix
FIREWALL_RULE_NAME = 'test-fw-rule' + args.gcp_suffix
BACKEND_SERVICE_NAME = 'test-backend-service' + args.gcp_suffix
URL_MAP_NAME = 'test-map' + args.gcp_suffix
SERVICE_HOST = 'grpc-test' + args.gcp_suffix
TARGET_PROXY_NAME = 'test-target-proxy' + args.gcp_suffix
FORWARDING_RULE_NAME = 'test-forwarding-rule' + args.gcp_suffix
KEEP_GCP_RESOURCES = args.keep_gcp_resources
TOLERATE_GCP_ERRORS = args.tolerate_gcp_errors
SERVICE_PORT = 55551
STATS_PORT = 55552
INSTANCE_GROUP_SIZE = 2
WAIT_FOR_OPERATION_SEC = 60
NUM_TEST_RPCS = 10 * QPS
WAIT_FOR_STATS_SEC = 30
BOOTSTRAP_TEMPLATE = """
{{
"node": {{
"id": "{node_id}"
}},
"xds_servers": [{{
"server_uri": "trafficdirector.googleapis.com:443",
"channel_creds": [
{{
"type": "google_default",
"config": {{}}
}}
]
}}]
}}"""
def get_client_stats(num_rpcs, timeout_sec):
with grpc.insecure_channel('localhost:%d' % STATS_PORT) as channel:
stub = test_pb2_grpc.LoadBalancerStatsServiceStub(channel)
request = messages_pb2.LoadBalancerStatsRequest()
request.num_rpcs = num_rpcs
request.timeout_sec = timeout_sec
try:
response = stub.GetClientStats(request, wait_for_ready=True)
logger.debug('Invoked GetClientStats RPC: %s', response)
return response
except grpc.RpcError as rpc_error:
raise Exception('GetClientStats RPC failed')
def wait_until_only_given_backends_receive_load(backends, timeout_sec):
start_time = time.time()
error_msg = None
while time.time() - start_time <= timeout_sec:
error_msg = None
stats = get_client_stats(max(len(backends), 1), timeout_sec)
rpcs_by_peer = stats.rpcs_by_peer
for backend in backends:
if backend not in rpcs_by_peer:
error_msg = 'Backend %s did not receive load' % backend
break
if not error_msg and len(rpcs_by_peer) > len(backends):
error_msg = 'Unexpected backend received load: %s' % rpcs_by_peer
if not error_msg:
return
raise Exception(error_msg)
def test_ping_pong(backends, num_rpcs, stats_timeout_sec):
start_time = time.time()
error_msg = None
while time.time() - start_time <= stats_timeout_sec:
error_msg = None
stats = get_client_stats(num_rpcs, stats_timeout_sec)
rpcs_by_peer = stats.rpcs_by_peer
for backend in backends:
if backend not in rpcs_by_peer:
error_msg = 'Backend %s did not receive load' % backend
break
if not error_msg and len(rpcs_by_peer) > len(backends):
error_msg = 'Unexpected backend received load: %s' % rpcs_by_peer
if not error_msg:
return
raise Exception(error_msg)
def test_round_robin(backends, num_rpcs, stats_timeout_sec):
threshold = 1
wait_until_only_given_backends_receive_load(backends, stats_timeout_sec)
stats = get_client_stats(num_rpcs, stats_timeout_sec)
requests_received = [stats.rpcs_by_peer[x] for x in stats.rpcs_by_peer]
total_requests_received = sum(
[stats.rpcs_by_peer[x] for x in stats.rpcs_by_peer])
if total_requests_received != num_rpcs:
raise Exception('Unexpected RPC failures', stats)
expected_requests = total_requests_received / len(backends)
for backend in backends:
if abs(stats.rpcs_by_peer[backend] - expected_requests) > threshold:
raise Exception(
'RPC peer distribution differs from expected by more than %d for backend %s (%s)',
threshold, backend, stats)
def create_instance_template(compute, name, grpc_port, project):
config = {
'name': name,
'properties': {
'tags': {
'items': ['grpc-allow-healthcheck']
},
'machineType': 'e2-standard-2',
'serviceAccounts': [{
'email': 'default',
'scopes': ['https://www.googleapis.com/auth/cloud-platform',]
}],
'networkInterfaces': [{
'accessConfigs': [{
'type': 'ONE_TO_ONE_NAT'
}],
'network': 'global/networks/default'
}],
'disks': [{
'boot': True,
'initializeParams': {
'sourceImage':
'projects/debian-cloud/global/images/family/debian-9'
}
}],
'metadata': {
'items': [{
'key':
'startup-script',
'value':
"""#!/bin/bash
sudo apt update
sudo apt install -y git default-jdk
mkdir java_server
pushd java_server
git clone https://github.com/grpc/grpc-java.git
pushd grpc-java
pushd interop-testing
../gradlew installDist -x test -PskipCodegen=true -PskipAndroid=true
nohup build/install/grpc-interop-testing/bin/xds-test-server --port=%d 1>/dev/null &"""
% grpc_port
}]
}
}
}
result = compute.instanceTemplates().insert(project=project,
body=config).execute()
wait_for_global_operation(compute, project, result['name'])
return result['targetLink']
def create_instance_group(compute, name, size, grpc_port, template_url, project,
zone):
config = {
'name': name,
'instanceTemplate': template_url,
'targetSize': size,
'namedPorts': [{
'name': 'grpc',
'port': grpc_port
}]
}
result = compute.instanceGroupManagers().insert(project=project,
zone=zone,
body=config).execute()
wait_for_zone_operation(compute, project, zone, result['name'])
result = compute.instanceGroupManagers().get(
project=PROJECT_ID, zone=ZONE, instanceGroupManager=name).execute()
return result['instanceGroup']
def create_health_check(compute, name, project):
config = {
'name': name,
'type': 'TCP',
'tcpHealthCheck': {
'portName': 'grpc'
}
}
result = compute.healthChecks().insert(project=project,
body=config).execute()
wait_for_global_operation(compute, project, result['name'])
return result['targetLink']
def create_health_check_firewall_rule(compute, name, project):
config = {
'name': name,
'direction': 'INGRESS',
'allowed': [{
'IPProtocol': 'tcp'
}],
'sourceRanges': ['35.191.0.0/16', '130.211.0.0/22'],
'targetTags': ['grpc-allow-healthcheck'],
}
result = compute.firewalls().insert(project=project, body=config).execute()
wait_for_global_operation(compute, project, result['name'])
def create_backend_service(compute, name, instance_group, health_check,
project):
config = {
'name': name,
'loadBalancingScheme': 'INTERNAL_SELF_MANAGED',
'healthChecks': [health_check],
'portName': 'grpc',
'protocol': 'HTTP2',
'backends': [{
'group': instance_group,
}]
}
result = compute.backendServices().insert(project=project,
body=config).execute()
wait_for_global_operation(compute, project, result['name'])
return result['targetLink']
def create_url_map(compute, name, backend_service_url, host_name, project):
path_matcher_name = 'path-matcher'
config = {
'name': name,
'defaultService': backend_service_url,
'pathMatchers': [{
'name': path_matcher_name,
'defaultService': backend_service_url,
}],
'hostRules': [{
'hosts': [host_name],
'pathMatcher': path_matcher_name
}]
}
result = compute.urlMaps().insert(project=project, body=config).execute()
wait_for_global_operation(compute, project, result['name'])
return result['targetLink']
def create_target_http_proxy(compute, name, url_map_url, project):
config = {
'name': name,
'url_map': url_map_url,
}
result = compute.targetHttpProxies().insert(project=project,
body=config).execute()
wait_for_global_operation(compute, project, result['name'])
return result['targetLink']
def create_global_forwarding_rule(compute, name, grpc_port,
target_http_proxy_url, project):
config = {
'name': name,
'loadBalancingScheme': 'INTERNAL_SELF_MANAGED',
'portRange': str(grpc_port),
'IPAddress': '0.0.0.0',
'target': target_http_proxy_url,
}
result = compute.globalForwardingRules().insert(project=project,
body=config).execute()
wait_for_global_operation(compute, project, result['name'])
def delete_global_forwarding_rule(compute, project, forwarding_rule):
try:
result = compute.globalForwardingRules().delete(
project=project, forwardingRule=forwarding_rule).execute()
wait_for_global_operation(compute, project, result['name'])
except googleapiclient.errors.HttpError as http_error:
logger.info('Delete failed: %s', http_error)
def delete_target_http_proxy(compute, project, target_http_proxy):
try:
result = compute.targetHttpProxies().delete(
project=project, targetHttpProxy=target_http_proxy).execute()
wait_for_global_operation(compute, project, result['name'])
except googleapiclient.errors.HttpError as http_error:
logger.info('Delete failed: %s', http_error)
def delete_url_map(compute, project, url_map):
try:
result = compute.urlMaps().delete(project=project,
urlMap=url_map).execute()
wait_for_global_operation(compute, project, result['name'])
except googleapiclient.errors.HttpError as http_error:
logger.info('Delete failed: %s', http_error)
def delete_backend_service(compute, project, backend_service):
try:
result = compute.backendServices().delete(
project=project, backendService=backend_service).execute()
wait_for_global_operation(compute, project, result['name'])
except googleapiclient.errors.HttpError as http_error:
logger.info('Delete failed: %s', http_error)
def delete_firewall(compute, project, firewall_rule):
try:
result = compute.firewalls().delete(project=project,
firewall=firewall_rule).execute()
wait_for_global_operation(compute, project, result['name'])
except googleapiclient.errors.HttpError as http_error:
logger.info('Delete failed: %s', http_error)
def delete_health_check(compute, project, health_check):
try:
result = compute.healthChecks().delete(
project=project, healthCheck=health_check).execute()
wait_for_global_operation(compute, project, result['name'])
except googleapiclient.errors.HttpError as http_error:
logger.info('Delete failed: %s', http_error)
def delete_instance_group(compute, project, zone, instance_group):
try:
result = compute.instanceGroupManagers().delete(
project=project, zone=zone,
instanceGroupManager=instance_group).execute()
timeout_sec = 180 # Deleting an instance group can be slow
wait_for_zone_operation(compute,
project,
ZONE,
result['name'],
timeout_sec=timeout_sec)
except googleapiclient.errors.HttpError as http_error:
logger.info('Delete failed: %s', http_error)
def delete_instance_template(compute, project, instance_template):
try:
result = compute.instanceTemplates().delete(
project=project, instanceTemplate=instance_template).execute()
wait_for_global_operation(compute, project, result['name'])
except googleapiclient.errors.HttpError as http_error:
logger.info('Delete failed: %s', http_error)
def wait_for_global_operation(compute,
project,
operation,
timeout_sec=WAIT_FOR_OPERATION_SEC):
start_time = time.time()
while time.time() - start_time <= timeout_sec:
result = compute.globalOperations().get(project=project,
operation=operation).execute()
if result['status'] == 'DONE':
if 'error' in result:
raise Exception(result['error'])
return
time.sleep(1)
raise Exception('Operation %s did not complete within %d', operation,
timeout_sec)
def wait_for_zone_operation(compute,
project,
zone,
operation,
timeout_sec=WAIT_FOR_OPERATION_SEC):
start_time = time.time()
while time.time() - start_time <= timeout_sec:
result = compute.zoneOperations().get(project=project,
zone=zone,
operation=operation).execute()
if result['status'] == 'DONE':
if 'error' in result:
raise Exception(result['error'])
return
time.sleep(1)
raise Exception('Operation %s did not complete within %d', operation,
timeout_sec)
def wait_for_healthy_backends(compute, project_id, backend_service,
instance_group_url, timeout_sec):
start_time = time.time()
config = {'group': instance_group_url}
while time.time() - start_time <= timeout_sec:
result = compute.backendServices().getHealth(
project=project_id, backendService=backend_service,
body=config).execute()
if 'healthStatus' in result:
healthy = True
for instance in result['healthStatus']:
if instance['healthState'] != 'HEALTHY':
healthy = False
break
if healthy:
return
time.sleep(1)
raise Exception('Not all backends became healthy within %d seconds: %s' %
(timeout_sec, result))
def start_xds_client():
cmd = CLIENT_CMD.format(service_host=SERVICE_HOST,
service_port=SERVICE_PORT,
stats_port=STATS_PORT,
qps=QPS)
bootstrap_path = None
with tempfile.NamedTemporaryFile(delete=False) as bootstrap_file:
bootstrap_file.write(
BOOTSTRAP_TEMPLATE.format(
node_id=socket.gethostname()).encode('utf-8'))
bootstrap_path = bootstrap_file.name
client_process = subprocess.Popen(shlex.split(cmd),
env=dict(
os.environ,
GRPC_XDS_BOOTSTRAP=bootstrap_path))
return client_process
compute = googleapiclient.discovery.build('compute', 'v1')
client_process = None
try:
instance_group_url = None
try:
template_url = create_instance_template(compute, TEMPLATE_NAME,
SERVICE_PORT, PROJECT_ID)
instance_group_url = create_instance_group(compute, INSTANCE_GROUP_NAME,
INSTANCE_GROUP_SIZE,
SERVICE_PORT, template_url,
PROJECT_ID, ZONE)
health_check_url = create_health_check(compute, HEALTH_CHECK_NAME,
PROJECT_ID)
create_health_check_firewall_rule(compute, FIREWALL_RULE_NAME,
PROJECT_ID)
backend_service_url = create_backend_service(compute,
BACKEND_SERVICE_NAME,
instance_group_url,
health_check_url,
PROJECT_ID)
url_map_url = create_url_map(compute, URL_MAP_NAME, backend_service_url,
SERVICE_HOST, PROJECT_ID)
target_http_proxy_url = create_target_http_proxy(
compute, TARGET_PROXY_NAME, url_map_url, PROJECT_ID)
create_global_forwarding_rule(compute, FORWARDING_RULE_NAME,
SERVICE_PORT, target_http_proxy_url,
PROJECT_ID)
except googleapiclient.errors.HttpError as http_error:
if TOLERATE_GCP_ERRORS:
logger.warning(
'Failed to set up backends: %s. Continuing since '
'--tolerate_gcp_errors=true', http_error)
else:
raise http_error
if instance_group_url is None:
# Look up the instance group URL, which may be unset if we are running
# with --tolerate_gcp_errors=true.
result = compute.instanceGroups().get(
project=PROJECT_ID, zone=ZONE,
instanceGroup=INSTANCE_GROUP_NAME).execute()
instance_group_url = result['selfLink']
wait_for_healthy_backends(compute, PROJECT_ID, BACKEND_SERVICE_NAME,
instance_group_url, WAIT_FOR_BACKEND_SEC)
backends = []
result = compute.instanceGroups().listInstances(
project=PROJECT_ID,
zone=ZONE,
instanceGroup=INSTANCE_GROUP_NAME,
body={
'instanceState': 'ALL'
}).execute()
for item in result['items']:
# listInstances() returns the full URL of the instance, which ends with
# the instance name. compute.instances().get() requires using the
# instance name (not the full URL) to look up instance details, so we
# just extract the name manually.
instance_name = item['instance'].split('/')[-1]
backends.append(instance_name)
client_process = start_xds_client()
if TEST_CASE == 'all':
test_ping_pong(backends, NUM_TEST_RPCS, WAIT_FOR_STATS_SEC)
test_round_robin(backends, NUM_TEST_RPCS, WAIT_FOR_STATS_SEC)
elif TEST_CASE == 'ping_pong':
test_ping_pong(backends, NUM_TEST_RPCS, WAIT_FOR_STATS_SEC)
elif TEST_CASE == 'round_robin':
test_round_robin(backends, NUM_TEST_RPCS, WAIT_FOR_STATS_SEC)
else:
logger.error('Unknown test case: %s', TEST_CASE)
sys.exit(1)
finally:
if client_process:
client_process.terminate()
if not KEEP_GCP_RESOURCES:
logger.info('Cleaning up GCP resources. This may take some time.')
delete_global_forwarding_rule(compute, PROJECT_ID, FORWARDING_RULE_NAME)
delete_target_http_proxy(compute, PROJECT_ID, TARGET_PROXY_NAME)
delete_url_map(compute, PROJECT_ID, URL_MAP_NAME)
delete_backend_service(compute, PROJECT_ID, BACKEND_SERVICE_NAME)
delete_firewall(compute, PROJECT_ID, FIREWALL_RULE_NAME)
delete_health_check(compute, PROJECT_ID, HEALTH_CHECK_NAME)
delete_instance_group(compute, PROJECT_ID, ZONE, INSTANCE_GROUP_NAME)
delete_instance_template(compute, PROJECT_ID, TEMPLATE_NAME)
Loading…
Cancel
Save