Merge branch 'master' into DeexperimentalizeCsmPluginOption

pull/35526/head
Yash Tibrewal 1 year ago
commit 441eccb857
Changed files (lines changed per file):
1. CMakeLists.txt (2)
2. black.toml (1)
3. build_autogenerated.yaml (6)
4. src/core/ext/filters/client_channel/client_channel.cc (7)
5. src/core/ext/filters/client_channel/client_channel_internal.h (2)
6. src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_impl.cc (10)
7. src/core/ext/xds/xds_cluster.cc (32)
8. src/core/ext/xds/xds_cluster.h (2)
9. src/core/lib/channel/call_tracer.cc (7)
10. src/core/lib/channel/call_tracer.h (10)
11. src/core/tsi/ssl_transport_security.cc (6)
12. src/cpp/ext/csm/metadata_exchange.cc (42)
13. src/cpp/ext/csm/metadata_exchange.h (14)
14. src/cpp/ext/filters/census/open_census_call_tracer.h (3)
15. src/cpp/ext/otel/key_value_iterable.h (40)
16. src/cpp/ext/otel/otel_call_tracer.h (7)
17. src/cpp/ext/otel/otel_client_filter.cc (20)
18. src/cpp/ext/otel/otel_plugin.h (26)
19. src/cpp/ext/otel/otel_server_call_tracer.cc (13)
20. src/proto/grpc/testing/xds/v3/base.proto (40)
21. src/proto/grpc/testing/xds/v3/cluster.proto (7)
22. src/python/grpcio_observability/grpc_observability/client_call_tracer.h (4)
23. test/core/channel/BUILD (1)
24. test/core/channel/call_tracer_test.cc (102)
25. test/core/client_channel/lb_policy/lb_policy_test_lib.h (4)
26. test/core/end2end/tests/http2_stats.cc (5)
27. test/core/util/BUILD (10)
28. test/core/util/fake_stats_plugin.cc (82)
29. test/core/util/fake_stats_plugin.h (192)
30. test/core/xds/xds_cluster_resource_type_test.cc (90)
31. test/cpp/end2end/xds/BUILD (1)
32. test/cpp/end2end/xds/xds_cluster_end2end_test.cc (39)
33. test/cpp/ext/csm/metadata_exchange_test.cc (39)
34. test/cpp/ext/otel/otel_plugin_test.cc (28)
35. test/cpp/ext/otel/otel_test_library.cc (59)
36. test/cpp/ext/otel/otel_test_library.h (2)
37. tools/distrib/pylint_code.sh (3)
38. tools/internal_ci/linux/grpc_build_submodule_at_head.sh (2)
39. tools/internal_ci/linux/grpc_xds_k8s_install_test_driver.sh (2)
40. tools/internal_ci/linux/grpc_xds_k8s_lb.sh (2)
41. tools/internal_ci/linux/grpc_xds_k8s_lb_python.sh (2)
42. tools/internal_ci/linux/grpc_xds_k8s_run_xtest.sh (2)
43. tools/internal_ci/linux/grpc_xds_url_map.sh (2)
44. tools/internal_ci/linux/grpc_xds_url_map_python.sh (2)
45. tools/internal_ci/linux/psm-csm.sh (2)
46. tools/internal_ci/linux/psm-security-python.sh (2)
47. tools/internal_ci/linux/psm-security.sh (2)
48. tools/run_tests/run_tests_matrix.py (3)
49. tools/run_tests/xds_k8s_test_driver/.gitignore (5)
50. tools/run_tests/xds_k8s_test_driver/README.md (457)
51. tools/run_tests/xds_k8s_test_driver/bin/__init__.py (13)
52. tools/run_tests/xds_k8s_test_driver/bin/black.sh (60)
53. tools/run_tests/xds_k8s_test_driver/bin/cleanup.sh (59)
54. tools/run_tests/xds_k8s_test_driver/bin/cleanup/README.md (2)
55. tools/run_tests/xds_k8s_test_driver/bin/cleanup/cleanup.py (714)
56. tools/run_tests/xds_k8s_test_driver/bin/cleanup/keep_xds_interop_resources.json (8)
57. tools/run_tests/xds_k8s_test_driver/bin/cleanup_cluster.sh (95)
58. tools/run_tests/xds_k8s_test_driver/bin/ensure_venv.sh (29)
59. tools/run_tests/xds_k8s_test_driver/bin/freeze.sh (28)
60. tools/run_tests/xds_k8s_test_driver/bin/isort.sh (60)
61. tools/run_tests/xds_k8s_test_driver/bin/lib/__init__.py (13)
62. tools/run_tests/xds_k8s_test_driver/bin/lib/common.py (191)
63. tools/run_tests/xds_k8s_test_driver/bin/run_channelz.py (271)
64. tools/run_tests/xds_k8s_test_driver/bin/run_ping_pong.py (169)
65. tools/run_tests/xds_k8s_test_driver/bin/run_td_setup.py (310)
66. tools/run_tests/xds_k8s_test_driver/bin/run_test_client.py (162)
67. tools/run_tests/xds_k8s_test_driver/bin/run_test_server.py (123)
68. tools/run_tests/xds_k8s_test_driver/config/common-csm.cfg (3)
69. tools/run_tests/xds_k8s_test_driver/config/common.cfg (11)
70. tools/run_tests/xds_k8s_test_driver/config/gamma.cfg (4)
71. tools/run_tests/xds_k8s_test_driver/config/grpc-testing.cfg (9)
72. tools/run_tests/xds_k8s_test_driver/config/local-dev.cfg.example (62)
73. tools/run_tests/xds_k8s_test_driver/config/url-map.cfg (15)
74. tools/run_tests/xds_k8s_test_driver/framework/__init__.py (13)
75. tools/run_tests/xds_k8s_test_driver/framework/bootstrap_generator_testcase.py (182)
76. tools/run_tests/xds_k8s_test_driver/framework/errors.py (58)
77. tools/run_tests/xds_k8s_test_driver/framework/helpers/__init__.py (13)
78. tools/run_tests/xds_k8s_test_driver/framework/helpers/datetime.py (79)
79. tools/run_tests/xds_k8s_test_driver/framework/helpers/grpc.py (204)
80. tools/run_tests/xds_k8s_test_driver/framework/helpers/highlighter.py (106)
81. tools/run_tests/xds_k8s_test_driver/framework/helpers/logs.py (48)
82. tools/run_tests/xds_k8s_test_driver/framework/helpers/rand.py (49)
83. tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py (273)
84. tools/run_tests/xds_k8s_test_driver/framework/helpers/skips.py (103)
85. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/__init__.py (13)
86. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/__init__.py (18)
87. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/api.py (542)
88. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py (637)
89. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/iam.py (361)
90. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_security.py (221)
91. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_services.py (461)
92. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s.py (1152)
93. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/__init__.py (13)
94. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_log_collector.py (142)
95. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_port_forwarder.py (133)
96. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py (1118)
97. tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director_gamma.py (22)
98. tools/run_tests/xds_k8s_test_driver/framework/rpc/__init__.py (14)
99. tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py (117)
100. tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py (273)
  51. 13
      tools/run_tests/xds_k8s_test_driver/bin/__init__.py
  52. 60
      tools/run_tests/xds_k8s_test_driver/bin/black.sh
  53. 59
      tools/run_tests/xds_k8s_test_driver/bin/cleanup.sh
  54. 2
      tools/run_tests/xds_k8s_test_driver/bin/cleanup/README.md
  55. 714
      tools/run_tests/xds_k8s_test_driver/bin/cleanup/cleanup.py
  56. 8
      tools/run_tests/xds_k8s_test_driver/bin/cleanup/keep_xds_interop_resources.json
  57. 95
      tools/run_tests/xds_k8s_test_driver/bin/cleanup_cluster.sh
  58. 29
      tools/run_tests/xds_k8s_test_driver/bin/ensure_venv.sh
  59. 28
      tools/run_tests/xds_k8s_test_driver/bin/freeze.sh
  60. 60
      tools/run_tests/xds_k8s_test_driver/bin/isort.sh
  61. 13
      tools/run_tests/xds_k8s_test_driver/bin/lib/__init__.py
  62. 191
      tools/run_tests/xds_k8s_test_driver/bin/lib/common.py
  63. 271
      tools/run_tests/xds_k8s_test_driver/bin/run_channelz.py
  64. 169
      tools/run_tests/xds_k8s_test_driver/bin/run_ping_pong.py
  65. 310
      tools/run_tests/xds_k8s_test_driver/bin/run_td_setup.py
  66. 162
      tools/run_tests/xds_k8s_test_driver/bin/run_test_client.py
  67. 123
      tools/run_tests/xds_k8s_test_driver/bin/run_test_server.py
  68. 3
      tools/run_tests/xds_k8s_test_driver/config/common-csm.cfg
  69. 11
      tools/run_tests/xds_k8s_test_driver/config/common.cfg
  70. 4
      tools/run_tests/xds_k8s_test_driver/config/gamma.cfg
  71. 9
      tools/run_tests/xds_k8s_test_driver/config/grpc-testing.cfg
  72. 62
      tools/run_tests/xds_k8s_test_driver/config/local-dev.cfg.example
  73. 15
      tools/run_tests/xds_k8s_test_driver/config/url-map.cfg
  74. 13
      tools/run_tests/xds_k8s_test_driver/framework/__init__.py
  75. 182
      tools/run_tests/xds_k8s_test_driver/framework/bootstrap_generator_testcase.py
  76. 58
      tools/run_tests/xds_k8s_test_driver/framework/errors.py
  77. 13
      tools/run_tests/xds_k8s_test_driver/framework/helpers/__init__.py
  78. 79
      tools/run_tests/xds_k8s_test_driver/framework/helpers/datetime.py
  79. 204
      tools/run_tests/xds_k8s_test_driver/framework/helpers/grpc.py
  80. 106
      tools/run_tests/xds_k8s_test_driver/framework/helpers/highlighter.py
  81. 48
      tools/run_tests/xds_k8s_test_driver/framework/helpers/logs.py
  82. 49
      tools/run_tests/xds_k8s_test_driver/framework/helpers/rand.py
  83. 273
      tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py
  84. 103
      tools/run_tests/xds_k8s_test_driver/framework/helpers/skips.py
  85. 13
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/__init__.py
  86. 18
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/__init__.py
  87. 542
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/api.py
  88. 637
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py
  89. 361
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/iam.py
  90. 221
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_security.py
  91. 461
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_services.py
  92. 1152
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s.py
  93. 13
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/__init__.py
  94. 142
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_log_collector.py
  95. 133
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_port_forwarder.py
  96. 1118
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py
  97. 22
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director_gamma.py
  98. 14
      tools/run_tests/xds_k8s_test_driver/framework/rpc/__init__.py
  99. 117
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py
  100. 273
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py
Some files were not shown because too many files have changed in this diff.

CMakeLists.txt (generated)

@@ -7809,6 +7809,7 @@ if(gRPC_BUILD_TESTS)
add_executable(call_tracer_test
test/core/channel/call_tracer_test.cc
test/core/util/fake_stats_plugin.cc
)
target_compile_features(call_tracer_test PUBLIC cxx_std_14)
target_include_directories(call_tracer_test
@@ -27043,6 +27044,7 @@ if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX)
${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/xds/v3/string.grpc.pb.cc
${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/xds/v3/string.pb.h
${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/xds/v3/string.grpc.pb.h
test/core/util/fake_stats_plugin.cc
test/cpp/end2end/connection_attempt_injector.cc
test/cpp/end2end/test_service_impl.cc
test/cpp/end2end/xds/xds_cluster_end2end_test.cc

black.toml

@@ -30,7 +30,6 @@ line_length = 80
src_paths = [
"examples/python/data_transmission",
"examples/python/async_streaming",
"tools/run_tests/xds_k8s_test_driver",
"src/python/grpcio_tests",
"tools/run_tests",
]

build_autogenerated.yaml

@@ -6418,9 +6418,11 @@ targets:
gtest: true
build: test
language: c++
headers: []
headers:
- test/core/util/fake_stats_plugin.h
src:
- test/core/channel/call_tracer_test.cc
- test/core/util/fake_stats_plugin.cc
deps:
- gtest
- grpc_test_util
@@ -18016,6 +18018,7 @@ targets:
run: false
language: c++
headers:
- test/core/util/fake_stats_plugin.h
- test/core/util/scoped_env_var.h
- test/cpp/end2end/connection_attempt_injector.h
- test/cpp/end2end/counted_service.h
@@ -18056,6 +18059,7 @@ targets:
- src/proto/grpc/testing/xds/v3/route.proto
- src/proto/grpc/testing/xds/v3/router.proto
- src/proto/grpc/testing/xds/v3/string.proto
- test/core/util/fake_stats_plugin.cc
- test/cpp/end2end/connection_attempt_injector.cc
- test/cpp/end2end/test_service_impl.cc
- test/cpp/end2end/xds/xds_cluster_end2end_test.cc

src/core/ext/filters/client_channel/client_channel.cc

@@ -2603,6 +2603,8 @@ class ClientChannel::LoadBalancedCall::LbCallState
ServiceConfigCallData::CallAttributeInterface* GetCallAttribute(
UniqueTypeName type) const override;
ClientCallTracer::CallAttemptTracer* GetCallAttemptTracer() const override;
private:
LoadBalancedCall* lb_call_;
};
@@ -2694,6 +2696,11 @@ ClientChannel::LoadBalancedCall::LbCallState::GetCallAttribute(
return service_config_call_data->GetCallAttribute(type);
}
ClientCallTracer::CallAttemptTracer*
ClientChannel::LoadBalancedCall::LbCallState::GetCallAttemptTracer() const {
return lb_call_->call_attempt_tracer();
}
//
// ClientChannel::LoadBalancedCall::BackendMetricAccessor
//

src/core/ext/filters/client_channel/client_channel_internal.h

@@ -25,6 +25,7 @@
#include <grpc/support/log.h>
#include "src/core/lib/channel/call_tracer.h"
#include "src/core/lib/channel/context.h"
#include "src/core/lib/gprpp/unique_type_name.h"
#include "src/core/lib/load_balancing/lb_policy.h"
@@ -49,6 +50,7 @@ class ClientChannelLbCallState : public LoadBalancingPolicy::CallState {
public:
virtual ServiceConfigCallData::CallAttributeInterface* GetCallAttribute(
UniqueTypeName type) const = 0;
virtual ClientCallTracer::CallAttemptTracer* GetCallAttemptTracer() const = 0;
};
// Internal type for ServiceConfigCallData. Handles call commits.

src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_impl.cc

@@ -37,6 +37,7 @@
#include <grpc/impl/connectivity_state.h>
#include <grpc/support/log.h>
#include "src/core/ext/filters/client_channel/client_channel_internal.h"
#include "src/core/ext/filters/client_channel/lb_policy/backend_metric_data.h"
#include "src/core/ext/filters/client_channel/lb_policy/child_policy_handler.h"
#include "src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_args.h"
@@ -76,6 +77,8 @@ TraceFlag grpc_xds_cluster_impl_lb_trace(false, "xds_cluster_impl_lb");
namespace {
using OptionalLabelComponent =
ClientCallTracer::CallAttemptTracer::OptionalLabelComponent;
using XdsConfig = XdsDependencyManager::XdsConfig;
//
@@ -215,6 +218,7 @@ class XdsClusterImplLb : public LoadBalancingPolicy {
RefCountedPtr<CircuitBreakerCallCounterMap::CallCounter> call_counter_;
uint32_t max_concurrent_requests_;
std::shared_ptr<std::map<std::string, std::string>> service_labels_;
RefCountedPtr<XdsEndpointResource::DropConfig> drop_config_;
RefCountedPtr<XdsClusterDropStats> drop_stats_;
RefCountedPtr<SubchannelPicker> picker_;
@@ -358,6 +362,7 @@ XdsClusterImplLb::Picker::Picker(XdsClusterImplLb* xds_cluster_impl_lb,
: call_counter_(xds_cluster_impl_lb->call_counter_),
max_concurrent_requests_(
xds_cluster_impl_lb->cluster_resource_->max_concurrent_requests),
service_labels_(xds_cluster_impl_lb->cluster_resource_->telemetry_labels),
drop_config_(xds_cluster_impl_lb->drop_config_),
drop_stats_(xds_cluster_impl_lb->drop_stats_),
picker_(std::move(picker)) {
@@ -369,6 +374,11 @@ XdsClusterImplLb::Picker::Picker(XdsClusterImplLb* xds_cluster_impl_lb,
LoadBalancingPolicy::PickResult XdsClusterImplLb::Picker::Pick(
LoadBalancingPolicy::PickArgs args) {
auto* call_state = static_cast<ClientChannelLbCallState*>(args.call_state);
if (call_state->GetCallAttemptTracer() != nullptr) {
call_state->GetCallAttemptTracer()->AddOptionalLabels(
OptionalLabelComponent::kXdsServiceLabels, service_labels_);
}
// Handle EDS drops.
const std::string* drop_category;
if (drop_config_ != nullptr && drop_config_->ShouldDrop(&drop_category)) {

src/core/ext/xds/xds_cluster.cc

@@ -48,6 +48,7 @@
#include "envoy/extensions/upstreams/http/v3/http_protocol_options.upb.h"
#include "google/protobuf/any.upb.h"
#include "google/protobuf/duration.upb.h"
#include "google/protobuf/struct.upb.h"
#include "google/protobuf/wrappers.upb.h"
#include "upb/base/string_view.h"
#include "upb/text/encode.h"
@@ -703,6 +704,37 @@ absl::StatusOr<std::shared_ptr<const XdsClusterResource>> CdsResourceParse(
cds_update->override_host_statuses.Add(
XdsHealthStatus(XdsHealthStatus::kHealthy));
}
// Record telemetry labels (if any).
const envoy_config_core_v3_Metadata* metadata =
envoy_config_cluster_v3_Cluster_metadata(cluster);
if (metadata != nullptr) {
google_protobuf_Struct* telemetry_labels_struct;
if (envoy_config_core_v3_Metadata_filter_metadata_get(
metadata,
StdStringToUpbString(
absl::string_view("com.google.csm.telemetry_labels")),
&telemetry_labels_struct)) {
auto telemetry_labels =
std::make_shared<std::map<std::string, std::string>>();
size_t iter = kUpb_Map_Begin;
const google_protobuf_Struct_FieldsEntry* fields_entry;
while ((fields_entry = google_protobuf_Struct_fields_next(
telemetry_labels_struct, &iter)) != nullptr) {
// Adds any entry whose value is a string to telemetry_labels.
const google_protobuf_Value* value =
google_protobuf_Struct_FieldsEntry_value(fields_entry);
if (google_protobuf_Value_has_string_value(value)) {
telemetry_labels->emplace(
UpbStringToStdString(
google_protobuf_Struct_FieldsEntry_key(fields_entry)),
UpbStringToStdString(google_protobuf_Value_string_value(value)));
}
}
if (!telemetry_labels->empty()) {
cds_update->telemetry_labels = std::move(telemetry_labels);
}
}
}
// Return result.
if (!errors.ok()) {
return errors.status(absl::StatusCode::kInvalidArgument,

src/core/ext/xds/xds_cluster.h

@@ -101,6 +101,8 @@ struct XdsClusterResource : public XdsResourceType::ResourceData {
XdsHealthStatusSet override_host_statuses;
std::shared_ptr<std::map<std::string, std::string>> telemetry_labels;
bool operator==(const XdsClusterResource& other) const {
return type == other.type && lb_policy_config == other.lb_policy_config &&
lrs_load_reporting_server == other.lrs_load_reporting_server &&

src/core/lib/channel/call_tracer.cc

@@ -146,6 +146,13 @@ class DelegatingClientCallTracer : public ClientCallTracer {
std::shared_ptr<TcpTracerInterface> StartNewTcpTrace() override {
return nullptr;
}
void AddOptionalLabels(
OptionalLabelComponent component,
std::shared_ptr<std::map<std::string, std::string>> labels) override {
for (auto* tracer : tracers_) {
tracer->AddOptionalLabels(component, labels);
}
}
std::string TraceId() override { return tracers_[0]->TraceId(); }
std::string SpanId() override { return tracers_[0]->SpanId(); }
bool IsSampled() override { return tracers_[0]->IsSampled(); }

src/core/lib/channel/call_tracer.h

@@ -128,6 +128,11 @@ class ClientCallTracer : public CallTracerAnnotationInterface {
// as transparent retry attempts.)
class CallAttemptTracer : public CallTracerInterface {
public:
enum class OptionalLabelComponent : std::uint8_t {
kXdsServiceLabels = 0,
kSize = 1, // keep last
};
~CallAttemptTracer() override {}
// TODO(yashykt): The following two methods `RecordReceivedTrailingMetadata`
// and `RecordEnd` should be moved into CallTracerInterface.
@@ -140,6 +145,11 @@ class ClientCallTracer : public CallTracerAnnotationInterface {
// Should be the last API call to the object. Once invoked, the tracer
// library is free to destroy the object.
virtual void RecordEnd(const gpr_timespec& latency) = 0;
// Adds optional labels to be reported by the underlying tracer in a call.
virtual void AddOptionalLabels(
OptionalLabelComponent component,
std::shared_ptr<std::map<std::string, std::string>> labels) = 0;
};
~ClientCallTracer() override {}
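To make the flow of the new API concrete: the LB picker (see XdsClusterImplLb::Picker::Pick above) calls AddOptionalLabels() on the attempt tracer, which holds the labels per component until they are reported at the end of the attempt. Below is a minimal, self-contained sketch; SketchAttemptTracer is a hypothetical stand-in that mirrors the enum-indexed array OpenTelemetryCallAttemptTracer adds later in this diff.

#include <array>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Mirrors the OptionalLabelComponent enum added above.
enum class OptionalLabelComponent : std::uint8_t {
  kXdsServiceLabels = 0,
  kSize = 1,  // keep last
};

class SketchAttemptTracer {
 public:
  // Stores labels per component; a second report for the same component
  // overwrites the slot, matching the OTel tracer's behavior.
  void AddOptionalLabels(
      OptionalLabelComponent component,
      std::shared_ptr<std::map<std::string, std::string>> labels) {
    optional_labels_array_[static_cast<std::size_t>(component)] =
        std::move(labels);
  }

  void Dump() const {
    for (const auto& labels : optional_labels_array_) {
      if (labels == nullptr) continue;
      for (const auto& kv : *labels) {
        std::cout << kv.first << "=" << kv.second << "\n";
      }
    }
  }

 private:
  std::array<std::shared_ptr<std::map<std::string, std::string>>,
             static_cast<std::size_t>(OptionalLabelComponent::kSize)>
      optional_labels_array_;
};

int main() {
  SketchAttemptTracer tracer;
  // In the real flow, XdsClusterImplLb::Picker::Pick() makes this call with
  // the telemetry labels parsed from the CDS resource.
  tracer.AddOptionalLabels(
      OptionalLabelComponent::kXdsServiceLabels,
      std::make_shared<std::map<std::string, std::string>>(
          std::map<std::string, std::string>{
              {"service_name", "myservice"},
              {"service_namespace", "mynamespace"}}));
  tracer.Dump();  // service_name=myservice, service_namespace=mynamespace
}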

src/core/tsi/ssl_transport_security.cc

@@ -2074,6 +2074,9 @@ tsi_result tsi_create_ssl_client_handshaker_factory_with_options(
ssl_context = SSL_CTX_new(TLS_method());
#else
ssl_context = SSL_CTX_new(TLSv1_2_method());
#endif
#if OPENSSL_VERSION_NUMBER >= 0x10101000
SSL_CTX_set_options(ssl_context, SSL_OP_NO_RENEGOTIATION);
#endif
if (ssl_context == nullptr) {
grpc_core::LogSslErrorStack();
@@ -2289,6 +2292,9 @@ tsi_result tsi_create_ssl_server_handshaker_factory_with_options(
impl->ssl_contexts[i] = SSL_CTX_new(TLS_method());
#else
impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method());
#endif
#if OPENSSL_VERSION_NUMBER >= 0x10101000
SSL_CTX_set_options(impl->ssl_contexts[i], SSL_OP_NO_RENEGOTIATION);
#endif
if (impl->ssl_contexts[i] == nullptr) {
grpc_core::LogSslErrorStack();
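A note on the guards above: 0x10101000 is the OPENSSL_VERSION_NUMBER encoding of OpenSSL 1.1.1, the first release that provides SSL_OP_NO_RENEGOTIATION, so renegotiation is disabled only when the headers advertise at least that version. The same pattern in isolation (DisableRenegotiation is a hypothetical helper, not part of this change):

#include <openssl/ssl.h>

// On OpenSSL 1.1.1+ this disables TLS renegotiation for all connections
// created from the context; on older versions it compiles to a no-op.
void DisableRenegotiation(SSL_CTX* ssl_context) {
#if OPENSSL_VERSION_NUMBER >= 0x10101000
  SSL_CTX_set_options(ssl_context, SSL_OP_NO_RENEGOTIATION);
#endif
}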

src/cpp/ext/csm/metadata_exchange.cc

@@ -42,6 +42,7 @@
#include <grpc/slice.h>
#include "src/core/lib/channel/call_tracer.h"
#include "src/core/lib/gprpp/env.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/iomgr/load_file.h"
@@ -49,10 +50,14 @@
#include "src/core/lib/json/json_object_loader.h"
#include "src/core/lib/json/json_reader.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/cpp/ext/otel/key_value_iterable.h"
namespace grpc {
namespace internal {
using OptionalLabelComponent =
grpc_core::ClientCallTracer::CallAttemptTracer::OptionalLabelComponent;
namespace {
// The keys that will be used in the Metadata Exchange between local and remote.
@@ -427,5 +432,42 @@ void ServiceMeshLabelsInjector::AddLabels(
serialized_labels_to_send_.Ref());
}
bool ServiceMeshLabelsInjector::AddOptionalLabels(
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>
optional_labels_span,
opentelemetry::nostd::function_ref<
bool(opentelemetry::nostd::string_view,
opentelemetry::common::AttributeValue)>
callback) const {
// According to the CSM Observability Metric spec, if the control plane fails
// to provide these labels, the client will set their values to "unknown".
// These default values are set below.
absl::string_view service_name = "unknown";
absl::string_view service_namespace = "unknown";
// Performs JSON label name format to CSM Observability Metric spec format
// conversion.
if (optional_labels_span.size() >
static_cast<size_t>(OptionalLabelComponent::kXdsServiceLabels)) {
const auto& optional_labels = optional_labels_span[static_cast<size_t>(
OptionalLabelComponent::kXdsServiceLabels)];
if (optional_labels != nullptr) {
auto it = optional_labels->find("service_name");
if (it != optional_labels->end()) service_name = it->second;
it = optional_labels->find("service_namespace");
if (it != optional_labels->end()) service_namespace = it->second;
}
}
return callback("csm.service_name",
AbslStrViewToOpenTelemetryStrView(service_name)) &&
callback("csm.service_namespace_name",
AbslStrViewToOpenTelemetryStrView(service_namespace));
}
size_t ServiceMeshLabelsInjector::GetOptionalLabelsSize(
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>)
const {
return 2;
}
} // namespace internal
} // namespace grpc
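The defaulting rule above can be shown with a small runnable sketch (Lookup is a hypothetical helper mirroring the find-or-"unknown" logic; the printed keys are the attribute names passed to the callback):

#include <iostream>
#include <map>
#include <string>

// Mirrors the lookup in ServiceMeshLabelsInjector::AddOptionalLabels: a key
// missing from the xDS-provided labels falls back to "unknown".
std::string Lookup(const std::map<std::string, std::string>& labels,
                   const std::string& key) {
  auto it = labels.find(key);
  return it != labels.end() ? it->second : "unknown";
}

int main() {
  const std::map<std::string, std::string> labels = {
      {"service_name", "myservice"}, {"service_namespace", "mynamespace"}};
  std::cout << "csm.service_name=" << Lookup(labels, "service_name") << "\n";
  std::cout << "csm.service_namespace_name="
            << Lookup(labels, "service_namespace") << "\n";
  std::cout << "csm.service_name=" << Lookup({}, "service_name")
            << "\n";  // prints "unknown" when the labels are absent
}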

src/cpp/ext/csm/metadata_exchange.h

@@ -50,6 +50,20 @@ class ServiceMeshLabelsInjector : public LabelsInjector {
void AddLabels(grpc_metadata_batch* outgoing_initial_metadata,
LabelsIterable* labels_from_incoming_metadata) const override;
// Add optional labels to the traced calls.
bool AddOptionalLabels(
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>
optional_labels_span,
opentelemetry::nostd::function_ref<
bool(opentelemetry::nostd::string_view,
opentelemetry::common::AttributeValue)>
callback) const override;
// Gets the size of the actual optional labels.
size_t GetOptionalLabelsSize(
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>
optional_labels_span) const override;
private:
std::vector<std::pair<absl::string_view, std::string>> local_labels_;
grpc_core::Slice serialized_labels_to_send_;

src/cpp/ext/filters/census/open_census_call_tracer.h

@@ -103,6 +103,9 @@ class OpenCensusCallTracer : public grpc_core::ClientCallTracer {
void RecordAnnotation(absl::string_view annotation) override;
void RecordAnnotation(const Annotation& annotation) override;
std::shared_ptr<grpc_core::TcpTracerInterface> StartNewTcpTrace() override;
void AddOptionalLabels(
OptionalLabelComponent,
std::shared_ptr<std::map<std::string, std::string>>) override {}
experimental::CensusContext* context() { return &context_; }

src/cpp/ext/otel/key_value_iterable.h

@@ -53,11 +53,16 @@ class KeyValueIterable : public opentelemetry::common::KeyValueIterable {
const std::vector<std::unique_ptr<LabelsIterable>>&
injected_labels_from_plugin_options,
absl::Span<const std::pair<absl::string_view, absl::string_view>>
additional_labels)
additional_labels,
const ActivePluginOptionsView* active_plugin_options_view,
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>
optional_labels_span)
: injected_labels_iterable_(injected_labels_iterable),
injected_labels_from_plugin_options_(
injected_labels_from_plugin_options),
additional_labels_(additional_labels) {}
additional_labels_(additional_labels),
active_plugin_options_view_(active_plugin_options_view),
optional_labels_(optional_labels_span) {}
bool ForEachKeyValue(opentelemetry::nostd::function_ref<
bool(opentelemetry::nostd::string_view,
@@ -72,6 +77,21 @@ class KeyValueIterable : public opentelemetry::common::KeyValueIterable {
}
}
}
if (OpenTelemetryPluginState().labels_injector != nullptr &&
!OpenTelemetryPluginState().labels_injector->AddOptionalLabels(
optional_labels_, callback)) {
return false;
}
if (active_plugin_options_view_ != nullptr &&
!active_plugin_options_view_->ForEach(
[callback, this](
const InternalOpenTelemetryPluginOption& plugin_option,
size_t /*index*/) {
return plugin_option.labels_injector()->AddOptionalLabels(
optional_labels_, callback);
})) {
return false;
}
for (const auto& plugin_option_injected_iterable :
injected_labels_from_plugin_options_) {
if (plugin_option_injected_iterable != nullptr) {
@@ -104,6 +124,19 @@ class KeyValueIterable : public opentelemetry::common::KeyValueIterable {
}
}
size += additional_labels_.size();
if (OpenTelemetryPluginState().labels_injector != nullptr) {
size += OpenTelemetryPluginState().labels_injector->GetOptionalLabelsSize(
optional_labels_);
}
if (active_plugin_options_view_ != nullptr) {
active_plugin_options_view_->ForEach(
[&size, this](const InternalOpenTelemetryPluginOption& plugin_option,
size_t /*index*/) {
size += plugin_option.labels_injector()->GetOptionalLabelsSize(
optional_labels_);
return true;
});
}
return size;
}
@@ -113,6 +146,9 @@ class KeyValueIterable : public opentelemetry::common::KeyValueIterable {
injected_labels_from_plugin_options_;
absl::Span<const std::pair<absl::string_view, absl::string_view>>
additional_labels_;
const ActivePluginOptionsView* active_plugin_options_view_;
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>
optional_labels_;
};
} // namespace internal

src/cpp/ext/otel/otel_call_tracer.h

@@ -91,6 +91,9 @@ class OpenTelemetryCallTracer : public grpc_core::ClientCallTracer {
void RecordAnnotation(absl::string_view /*annotation*/) override;
void RecordAnnotation(const Annotation& /*annotation*/) override;
std::shared_ptr<grpc_core::TcpTracerInterface> StartNewTcpTrace() override;
void AddOptionalLabels(OptionalLabelComponent component,
std::shared_ptr<std::map<std::string, std::string>>
optional_labels) override;
private:
const OpenTelemetryCallTracer* parent_;
@@ -98,6 +101,10 @@ class OpenTelemetryCallTracer : public grpc_core::ClientCallTracer {
// Start time (for measuring latency).
absl::Time start_time_;
std::unique_ptr<LabelsIterable> injected_labels_;
// The indices of the array correspond to the OptionalLabelComponent enum.
std::array<std::shared_ptr<std::map<std::string, std::string>>,
static_cast<size_t>(OptionalLabelComponent::kSize)>
optional_labels_array_;
std::vector<std::unique_ptr<LabelsIterable>>
injected_labels_from_plugin_options_;
};

src/cpp/ext/otel/otel_client_filter.cc

@@ -132,7 +132,9 @@ OpenTelemetryCallTracer::OpenTelemetryCallAttemptTracer::
// avoid recording a subset of injected labels here.
OpenTelemetryPluginState().client.attempt.started->Add(
1, KeyValueIterable(/*injected_labels_iterable=*/nullptr, {},
additional_labels));
additional_labels,
/*active_plugin_options_view=*/nullptr,
/*optional_labels_span=*/{}));
}
}
@@ -150,6 +152,7 @@ void OpenTelemetryCallTracer::OpenTelemetryCallAttemptTracer::
injected_labels_from_plugin_options_.push_back(
labels_injector->GetLabels(recv_initial_metadata));
}
return true;
});
}
@@ -166,6 +169,7 @@ void OpenTelemetryCallTracer::OpenTelemetryCallAttemptTracer::
if (labels_injector != nullptr) {
labels_injector->AddLabels(send_initial_metadata, nullptr);
}
return true;
});
}
@@ -206,9 +210,10 @@ void OpenTelemetryCallTracer::OpenTelemetryCallAttemptTracer::
{OpenTelemetryStatusKey(),
grpc_status_code_to_string(
static_cast<grpc_status_code>(status.code()))}}};
KeyValueIterable labels(injected_labels_.get(),
injected_labels_from_plugin_options_,
additional_labels);
KeyValueIterable labels(
injected_labels_.get(), injected_labels_from_plugin_options_,
additional_labels, &parent_->parent_->active_plugin_options_view(),
optional_labels_array_);
if (OpenTelemetryPluginState().client.attempt.duration != nullptr) {
OpenTelemetryPluginState().client.attempt.duration->Record(
absl::ToDoubleSeconds(absl::Now() - start_time_), labels,
@@ -262,6 +267,13 @@ OpenTelemetryCallTracer::OpenTelemetryCallAttemptTracer::StartNewTcpTrace() {
return nullptr;
}
void OpenTelemetryCallTracer::OpenTelemetryCallAttemptTracer::AddOptionalLabels(
OptionalLabelComponent component,
std::shared_ptr<std::map<std::string, std::string>> optional_labels) {
optional_labels_array_[static_cast<std::size_t>(component)] =
std::move(optional_labels);
}
//
// OpenTelemetryCallTracer
//

src/cpp/ext/otel/otel_plugin.h

@@ -79,6 +79,23 @@ class LabelsInjector {
virtual void AddLabels(
grpc_metadata_batch* outgoing_initial_metadata,
LabelsIterable* labels_from_incoming_metadata) const = 0;
// Adds optional labels to the traced calls. Each entry in the span
// corresponds to the CallAttemptTracer::OptionalLabelComponent enum. Returns
// false when callback returns false.
virtual bool AddOptionalLabels(
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>
optional_labels_span,
opentelemetry::nostd::function_ref<
bool(opentelemetry::nostd::string_view,
opentelemetry::common::AttributeValue)>
callback) const = 0;
// Gets the actual size of the optional labels that the Plugin is going to
// produce through the AddOptionalLabels method.
virtual size_t GetOptionalLabelsSize(
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>
optional_labels_span) const = 0;
};
class InternalOpenTelemetryPluginOption
@@ -223,14 +240,17 @@ class ActivePluginOptionsView {
});
}
void ForEach(
absl::FunctionRef<void(const InternalOpenTelemetryPluginOption&, size_t)>
bool ForEach(
absl::FunctionRef<bool(const InternalOpenTelemetryPluginOption&, size_t)>
func) const {
for (size_t i = 0; i < OpenTelemetryPluginState().plugin_options.size();
++i) {
const auto& plugin_option = OpenTelemetryPluginState().plugin_options[i];
if (active_mask_[i]) func(*plugin_option, i);
if (active_mask_[i] && !func(*plugin_option, i)) {
return false;
}
}
return true;
}
private:
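The ForEach signature change (void to bool callbacks) lets a labels injector abort iteration, which KeyValueIterable::ForEachKeyValue relies on above. A standalone sketch of the contract, with ints standing in for plugin options:

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// Iteration stops at the first callback returning false; the return value
// reports whether the loop ran to completion, as in ActivePluginOptionsView.
bool ForEach(const std::vector<int>& options,
             const std::function<bool(int, std::size_t)>& func) {
  for (std::size_t i = 0; i < options.size(); ++i) {
    if (!func(options[i], i)) return false;
  }
  return true;
}

int main() {
  const std::vector<int> options = {1, 2, 3};
  const bool completed =
      ForEach(options, [](int value, std::size_t /*index*/) {
        return value < 3;  // fails on the third option, stopping early
      });
  std::cout << (completed ? "completed" : "stopped early") << "\n";
}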

src/cpp/ext/otel/otel_server_call_tracer.cc

@@ -95,6 +95,7 @@ class OpenTelemetryServerCallTracer : public grpc_core::ServerCallTracer {
send_initial_metadata,
injected_labels_from_plugin_options_[index].get());
}
return true;
});
}
@@ -186,6 +187,7 @@ void OpenTelemetryServerCallTracer::RecordReceivedInitialMetadata(
injected_labels_from_plugin_options_[index] =
labels_injector->GetLabels(recv_initial_metadata);
}
return true;
});
registered_method_ =
recv_initial_metadata->get(grpc_core::GrpcRegisteredMethod())
@@ -197,7 +199,8 @@ void OpenTelemetryServerCallTracer::RecordReceivedInitialMetadata(
// avoid recording a subset of injected labels here.
OpenTelemetryPluginState().server.call.started->Add(
1, KeyValueIterable(/*injected_labels_iterable=*/nullptr, {},
additional_labels));
additional_labels,
/*active_plugin_options_view=*/nullptr, {}));
}
}
@@ -215,9 +218,11 @@ void OpenTelemetryServerCallTracer::RecordEnd(
{{OpenTelemetryMethodKey(), MethodForStats()},
{OpenTelemetryStatusKey(),
grpc_status_code_to_string(final_info->final_status)}}};
KeyValueIterable labels(injected_labels_.get(),
injected_labels_from_plugin_options_,
additional_labels);
// Currently we do not have any optional labels on the server side.
KeyValueIterable labels(
injected_labels_.get(), injected_labels_from_plugin_options_,
additional_labels,
/*active_plugin_options_view=*/nullptr, /*optional_labels_span=*/{});
if (OpenTelemetryPluginState().server.call.duration != nullptr) {
OpenTelemetryPluginState().server.call.duration->Record(
absl::ToDoubleSeconds(elapsed_time_), labels,

src/proto/grpc/testing/xds/v3/base.proto

@@ -129,3 +129,43 @@ message TransportSocket {
google.protobuf.Any typed_config = 3;
}
}
// Metadata provides additional inputs to filters based on matched listeners,
// filter chains, routes and endpoints. It is structured as a map, usually from
// filter name (in reverse DNS format) to metadata specific to the filter. Metadata
// key-values for a filter are merged as connection and request handling occurs,
// with later values for the same key overriding earlier values.
//
// An example use of metadata is providing additional values to
// http_connection_manager in the envoy.http_connection_manager.access_log
// namespace.
//
// Another example use of metadata is to pass per-service config info in
// cluster metadata, which may get consumed by multiple filters.
//
// For load balancing, Metadata provides a means to subset cluster endpoints.
// Endpoints have a Metadata object associated and routes contain a Metadata
// object to match against. There are some well defined metadata used today for
// this purpose:
//
// * ``{"envoy.lb": {"canary": <bool> }}`` This indicates the canary status of an
// endpoint and is also used during header processing
// (x-envoy-upstream-canary) and for stats purposes.
// [#next-major-version: move to type/metadata/v2]
message Metadata {
// Key is the reverse DNS filter name, e.g. com.acme.widget. The ``envoy.*``
// namespace is reserved for Envoy's built-in filters.
// If both ``filter_metadata`` and
// :ref:`typed_filter_metadata <envoy_v3_api_field_config.core.v3.Metadata.typed_filter_metadata>`
// fields are present in the metadata with same keys,
// only ``typed_filter_metadata`` field will be parsed.
map<string, google.protobuf.Struct> filter_metadata = 1;
// Key is the reverse DNS filter name, e.g. com.acme.widget. The ``envoy.*``
// namespace is reserved for Envoy's built-in filters.
// The value is encoded as google.protobuf.Any.
// If both :ref:`filter_metadata <envoy_v3_api_field_config.core.v3.Metadata.filter_metadata>`
// and ``typed_filter_metadata`` fields are present in the metadata with same keys,
// only ``typed_filter_metadata`` field will be parsed.
map<string, google.protobuf.Any> typed_filter_metadata = 2;
}
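Tying this message back to the CDS change above: CdsResourceParse() reads the com.google.csm.telemetry_labels entry of filter_metadata and keeps only string-valued fields. A sketch of populating that entry from C++ (the generated header path is an assumption based on the .proto location; the tests later in this diff build the same structure through the Cluster message):

#include <string>

#include "src/proto/grpc/testing/xds/v3/base.pb.h"  // assumed generated header

// Hypothetical helper: fills the filter_metadata entry the parser looks for.
// Non-string values under this key would be ignored as telemetry labels.
void FillCsmTelemetryLabels(envoy::config::core::v3::Metadata* metadata,
                            const std::string& service_name,
                            const std::string& service_namespace) {
  auto& filter_map = *metadata->mutable_filter_metadata();
  auto& label_map =
      *filter_map["com.google.csm.telemetry_labels"].mutable_fields();
  *label_map["service_name"].mutable_string_value() = service_name;
  *label_map["service_namespace"].mutable_string_value() = service_namespace;
}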

src/proto/grpc/testing/xds/v3/cluster.proto

@@ -252,6 +252,13 @@ message Cluster {
// from the LRS stream here.]
core.v3.ConfigSource lrs_server = 42;
// The Metadata field can be used to provide additional information about the
// cluster. It can be used for stats, logging, and varying filter behavior.
// Fields should use reverse DNS notation to denote which entity within Envoy
// will need the information. For instance, if the metadata is intended for
// the Router filter, the filter name should be specified as ``envoy.filters.http.router``.
core.v3.Metadata metadata = 25;
core.v3.TypedExtensionConfig upstream_config = 48;
}

src/python/grpcio_observability/grpc_observability/client_call_tracer.h

@@ -73,6 +73,10 @@ class PythonOpenCensusCallTracer : public grpc_core::ClientCallTracer {
void RecordAnnotation(absl::string_view annotation) override;
void RecordAnnotation(const Annotation& annotation) override;
std::shared_ptr<grpc_core::TcpTracerInterface> StartNewTcpTrace() override;
void AddOptionalLabels(
OptionalLabelComponent /*component*/,
std::shared_ptr<std::map<std::string, std::string>> /*labels*/)
override {}
private:
// Maximum size of trace context is sent on the wire.

test/core/channel/BUILD

@@ -27,6 +27,7 @@ grpc_cc_test(
uses_polling = False,
deps = [
"//:grpc",
"//test/core/util:fake_stats_plugin",
"//test/core/util:grpc_test_util",
],
)

test/core/channel/call_tracer_test.cc

@@ -18,7 +18,6 @@
#include "src/core/lib/channel/call_tracer.h"
#include <memory>
#include <vector>
#include "gtest/gtest.h"
@@ -26,115 +25,16 @@
#include <grpc/event_engine/memory_allocator.h>
#include <grpc/grpc.h>
#include "src/core/lib/channel/tcp_tracer.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "test/core/util/fake_stats_plugin.h"
#include "test/core/util/test_config.h"
namespace grpc_core {
namespace {
class FakeClientCallTracer : public ClientCallTracer {
public:
class FakeClientCallAttemptTracer
: public ClientCallTracer::CallAttemptTracer {
public:
explicit FakeClientCallAttemptTracer(
std::vector<std::string>* annotation_logger)
: annotation_logger_(annotation_logger) {}
~FakeClientCallAttemptTracer() override {}
void RecordSendInitialMetadata(
grpc_metadata_batch* /*send_initial_metadata*/) override {}
void RecordSendTrailingMetadata(
grpc_metadata_batch* /*send_trailing_metadata*/) override {}
void RecordSendMessage(const SliceBuffer& /*send_message*/) override {}
void RecordSendCompressedMessage(
const SliceBuffer& /*send_compressed_message*/) override {}
void RecordReceivedInitialMetadata(
grpc_metadata_batch* /*recv_initial_metadata*/) override {}
void RecordReceivedMessage(const SliceBuffer& /*recv_message*/) override {}
void RecordReceivedDecompressedMessage(
const SliceBuffer& /*recv_decompressed_message*/) override {}
void RecordCancel(grpc_error_handle /*cancel_error*/) override {}
void RecordReceivedTrailingMetadata(
absl::Status /*status*/,
grpc_metadata_batch* /*recv_trailing_metadata*/,
const grpc_transport_stream_stats* /*transport_stream_stats*/)
override {}
void RecordEnd(const gpr_timespec& /*latency*/) override { delete this; }
void RecordAnnotation(absl::string_view annotation) override {
annotation_logger_->push_back(std::string(annotation));
}
void RecordAnnotation(const Annotation& /*annotation*/) override {}
std::shared_ptr<TcpTracerInterface> StartNewTcpTrace() override {
return nullptr;
}
std::string TraceId() override { return ""; }
std::string SpanId() override { return ""; }
bool IsSampled() override { return false; }
private:
std::vector<std::string>* annotation_logger_;
};
explicit FakeClientCallTracer(std::vector<std::string>* annotation_logger)
: annotation_logger_(annotation_logger) {}
~FakeClientCallTracer() override {}
CallAttemptTracer* StartNewAttempt(bool /*is_transparent_retry*/) override {
return GetContext<Arena>()->ManagedNew<FakeClientCallAttemptTracer>(
annotation_logger_);
}
void RecordAnnotation(absl::string_view annotation) override {
annotation_logger_->push_back(std::string(annotation));
}
void RecordAnnotation(const Annotation& /*annotation*/) override {}
std::string TraceId() override { return ""; }
std::string SpanId() override { return ""; }
bool IsSampled() override { return false; }
private:
std::vector<std::string>* annotation_logger_;
};
class FakeServerCallTracer : public ServerCallTracer {
public:
explicit FakeServerCallTracer(std::vector<std::string>* annotation_logger)
: annotation_logger_(annotation_logger) {}
~FakeServerCallTracer() override {}
void RecordSendInitialMetadata(
grpc_metadata_batch* /*send_initial_metadata*/) override {}
void RecordSendTrailingMetadata(
grpc_metadata_batch* /*send_trailing_metadata*/) override {}
void RecordSendMessage(const SliceBuffer& /*send_message*/) override {}
void RecordSendCompressedMessage(
const SliceBuffer& /*send_compressed_message*/) override {}
void RecordReceivedInitialMetadata(
grpc_metadata_batch* /*recv_initial_metadata*/) override {}
void RecordReceivedMessage(const SliceBuffer& /*recv_message*/) override {}
void RecordReceivedDecompressedMessage(
const SliceBuffer& /*recv_decompressed_message*/) override {}
void RecordCancel(grpc_error_handle /*cancel_error*/) override {}
void RecordReceivedTrailingMetadata(
grpc_metadata_batch* /*recv_trailing_metadata*/) override {}
void RecordEnd(const grpc_call_final_info* /*final_info*/) override {}
void RecordAnnotation(absl::string_view annotation) override {
annotation_logger_->push_back(std::string(annotation));
}
void RecordAnnotation(const Annotation& /*annotation*/) override {}
std::shared_ptr<TcpTracerInterface> StartNewTcpTrace() override {
return nullptr;
}
std::string TraceId() override { return ""; }
std::string SpanId() override { return ""; }
bool IsSampled() override { return false; }
private:
std::vector<std::string>* annotation_logger_;
};
class CallTracerTest : public ::testing::Test {
protected:
void SetUp() override {

test/core/client_channel/lb_policy/lb_policy_test_lib.h

@@ -671,6 +671,10 @@ class LoadBalancingPolicyTest : public ::testing::Test {
return nullptr;
}
ClientCallTracer::CallAttemptTracer* GetCallAttemptTracer() const override {
return nullptr;
}
std::vector<void*> allocations_;
std::map<UniqueTypeName, ServiceConfigCallData::CallAttributeInterface*>
attributes_;

test/core/end2end/tests/http2_stats.cc

@@ -97,6 +97,11 @@ class FakeCallTracer : public ClientCallTracer {
void RecordAnnotation(absl::string_view /*annotation*/) override {}
void RecordAnnotation(const Annotation& /*annotation*/) override {}
void AddOptionalLabels(
OptionalLabelComponent /*component*/,
std::shared_ptr<std::map<std::string, std::string>> /*labels*/)
override {}
static grpc_transport_stream_stats transport_stream_stats() {
MutexLock lock(g_mu);
return transport_stream_stats_;

test/core/util/BUILD

@@ -491,3 +491,13 @@ grpc_cc_library(
"//src/core:resource_quota",
],
)
grpc_cc_library(
name = "fake_stats_plugin",
srcs = ["fake_stats_plugin.cc"],
hdrs = ["fake_stats_plugin.h"],
deps = [
"//:grpc",
"//src/core:examine_stack",
],
)

test/core/util/fake_stats_plugin.cc (new file)

@@ -0,0 +1,82 @@
// Copyright 2023 The gRPC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test/core/util/fake_stats_plugin.h"
#include "src/core/lib/config/core_configuration.h"
namespace grpc_core {
class FakeStatsClientFilter : public ChannelFilter {
public:
static const grpc_channel_filter kFilter;
static absl::StatusOr<FakeStatsClientFilter> Create(
const ChannelArgs& /*args*/, ChannelFilter::Args /*filter_args*/);
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
private:
explicit FakeStatsClientFilter(
FakeClientCallTracerFactory* fake_client_call_tracer_factory);
FakeClientCallTracerFactory* const fake_client_call_tracer_factory_;
};
const grpc_channel_filter FakeStatsClientFilter::kFilter =
MakePromiseBasedFilter<FakeStatsClientFilter, FilterEndpoint::kClient>(
"fake_stats_client");
absl::StatusOr<FakeStatsClientFilter> FakeStatsClientFilter::Create(
const ChannelArgs& args, ChannelFilter::Args /*filter_args*/) {
auto* fake_client_call_tracer_factory =
args.GetPointer<FakeClientCallTracerFactory>(
GRPC_ARG_INJECT_FAKE_CLIENT_CALL_TRACER_FACTORY);
GPR_ASSERT(fake_client_call_tracer_factory != nullptr);
return FakeStatsClientFilter(fake_client_call_tracer_factory);
}
ArenaPromise<ServerMetadataHandle> FakeStatsClientFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
FakeClientCallTracer* client_call_tracer =
fake_client_call_tracer_factory_->CreateFakeClientCallTracer();
if (client_call_tracer != nullptr) {
auto* call_context = GetContext<grpc_call_context_element>();
call_context[GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE].value =
client_call_tracer;
call_context[GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE].destroy =
nullptr;
}
return next_promise_factory(std::move(call_args));
}
FakeStatsClientFilter::FakeStatsClientFilter(
FakeClientCallTracerFactory* fake_client_call_tracer_factory)
: fake_client_call_tracer_factory_(fake_client_call_tracer_factory) {}
void RegisterFakeStatsPlugin() {
CoreConfiguration::RegisterBuilder(
[](CoreConfiguration::Builder* builder) mutable {
builder->channel_init()
->RegisterFilter(GRPC_CLIENT_CHANNEL,
&FakeStatsClientFilter::kFilter)
.If([](const ChannelArgs& args) {
return args.GetPointer<FakeClientCallTracerFactory>(
GRPC_ARG_INJECT_FAKE_CLIENT_CALL_TRACER_FACTORY) !=
nullptr;
});
});
}
} // namespace grpc_core

test/core/util/fake_stats_plugin.h (new file)

@@ -0,0 +1,192 @@
// Copyright 2023 The gRPC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_TEST_CORE_UTIL_FAKE_STATS_PLUGIN_H
#define GRPC_TEST_CORE_UTIL_FAKE_STATS_PLUGIN_H
#include <memory>
#include <string>
#include <vector>
#include "src/core/lib/channel/call_tracer.h"
#include "src/core/lib/channel/promise_based_filter.h"
#include "src/core/lib/channel/tcp_tracer.h"
namespace grpc_core {
// Registers a FakeStatsClientFilter as a client channel filter if there is a
// FakeClientCallTracerFactory in the channel args. This filter will use the
// FakeClientCallTracerFactory to create and inject a FakeClientCallTracer into
// the call context.
// Example usage:
// RegisterFakeStatsPlugin(); // before grpc_init()
//
// // Creates a FakeClientCallTracerFactory and adds it into the channel args.
// FakeClientCallTracerFactory fake_client_call_tracer_factory;
// ChannelArguments channel_args;
// channel_args.SetPointer(GRPC_ARG_INJECT_FAKE_CLIENT_CALL_TRACER_FACTORY,
// &fake_client_call_tracer_factory);
//
// // After the system under test has been executed (e.g. an RPC has been
// // sent), use the FakeClientCallTracerFactory to verify certain
// // expectations.
// EXPECT_THAT(fake_client_call_tracer_factory.GetLastFakeClientCallTracer()
// ->GetLastCallAttemptTracer()
// ->GetOptionalLabels(),
// VerifyCsmServiceLabels());
void RegisterFakeStatsPlugin();
class FakeClientCallTracer : public ClientCallTracer {
public:
class FakeClientCallAttemptTracer
: public ClientCallTracer::CallAttemptTracer {
public:
explicit FakeClientCallAttemptTracer(
std::vector<std::string>* annotation_logger)
: annotation_logger_(annotation_logger) {}
~FakeClientCallAttemptTracer() override {}
void RecordSendInitialMetadata(
grpc_metadata_batch* /*send_initial_metadata*/) override {}
void RecordSendTrailingMetadata(
grpc_metadata_batch* /*send_trailing_metadata*/) override {}
void RecordSendMessage(const SliceBuffer& /*send_message*/) override {}
void RecordSendCompressedMessage(
const SliceBuffer& /*send_compressed_message*/) override {}
void RecordReceivedInitialMetadata(
grpc_metadata_batch* /*recv_initial_metadata*/) override {}
void RecordReceivedMessage(const SliceBuffer& /*recv_message*/) override {}
void RecordReceivedDecompressedMessage(
const SliceBuffer& /*recv_decompressed_message*/) override {}
void RecordCancel(grpc_error_handle /*cancel_error*/) override {}
void RecordReceivedTrailingMetadata(
absl::Status /*status*/,
grpc_metadata_batch* /*recv_trailing_metadata*/,
const grpc_transport_stream_stats* /*transport_stream_stats*/)
override {}
void RecordEnd(const gpr_timespec& /*latency*/) override {}
void RecordAnnotation(absl::string_view annotation) override {
annotation_logger_->push_back(std::string(annotation));
}
void RecordAnnotation(const Annotation& /*annotation*/) override {}
std::shared_ptr<TcpTracerInterface> StartNewTcpTrace() override {
return nullptr;
}
void AddOptionalLabels(
OptionalLabelComponent component,
std::shared_ptr<std::map<std::string, std::string>> labels) override {
optional_labels_.emplace(component, std::move(labels));
}
std::string TraceId() override { return ""; }
std::string SpanId() override { return ""; }
bool IsSampled() override { return false; }
const std::map<OptionalLabelComponent,
std::shared_ptr<std::map<std::string, std::string>>>&
GetOptionalLabels() const {
return optional_labels_;
}
private:
std::vector<std::string>* annotation_logger_;
std::map<OptionalLabelComponent,
std::shared_ptr<std::map<std::string, std::string>>>
optional_labels_;
};
explicit FakeClientCallTracer(std::vector<std::string>* annotation_logger)
: annotation_logger_(annotation_logger) {}
~FakeClientCallTracer() override {}
CallAttemptTracer* StartNewAttempt(bool /*is_transparent_retry*/) override {
call_attempt_tracers_.emplace_back(
new FakeClientCallAttemptTracer(annotation_logger_));
return call_attempt_tracers_.back().get();
}
void RecordAnnotation(absl::string_view annotation) override {
annotation_logger_->push_back(std::string(annotation));
}
void RecordAnnotation(const Annotation& /*annotation*/) override {}
std::string TraceId() override { return ""; }
std::string SpanId() override { return ""; }
bool IsSampled() override { return false; }
FakeClientCallAttemptTracer* GetLastCallAttemptTracer() const {
return call_attempt_tracers_.back().get();
}
private:
std::vector<std::string>* annotation_logger_;
std::vector<std::unique_ptr<FakeClientCallAttemptTracer>>
call_attempt_tracers_;
};
#define GRPC_ARG_INJECT_FAKE_CLIENT_CALL_TRACER_FACTORY \
"grpc.testing.inject_fake_client_call_tracer_factory"
class FakeClientCallTracerFactory {
public:
FakeClientCallTracer* CreateFakeClientCallTracer() {
fake_client_call_tracers_.emplace_back(
new FakeClientCallTracer(&annotation_logger_));
return fake_client_call_tracers_.back().get();
}
FakeClientCallTracer* GetLastFakeClientCallTracer() {
return fake_client_call_tracers_.back().get();
}
private:
std::vector<std::string> annotation_logger_;
std::vector<std::unique_ptr<FakeClientCallTracer>> fake_client_call_tracers_;
};
class FakeServerCallTracer : public ServerCallTracer {
public:
explicit FakeServerCallTracer(std::vector<std::string>* annotation_logger)
: annotation_logger_(annotation_logger) {}
~FakeServerCallTracer() override {}
void RecordSendInitialMetadata(
grpc_metadata_batch* /*send_initial_metadata*/) override {}
void RecordSendTrailingMetadata(
grpc_metadata_batch* /*send_trailing_metadata*/) override {}
void RecordSendMessage(const SliceBuffer& /*send_message*/) override {}
void RecordSendCompressedMessage(
const SliceBuffer& /*send_compressed_message*/) override {}
void RecordReceivedInitialMetadata(
grpc_metadata_batch* /*recv_initial_metadata*/) override {}
void RecordReceivedMessage(const SliceBuffer& /*recv_message*/) override {}
void RecordReceivedDecompressedMessage(
const SliceBuffer& /*recv_decompressed_message*/) override {}
void RecordCancel(grpc_error_handle /*cancel_error*/) override {}
void RecordReceivedTrailingMetadata(
grpc_metadata_batch* /*recv_trailing_metadata*/) override {}
void RecordEnd(const grpc_call_final_info* /*final_info*/) override {}
void RecordAnnotation(absl::string_view annotation) override {
annotation_logger_->push_back(std::string(annotation));
}
void RecordAnnotation(const Annotation& /*annotation*/) override {}
std::shared_ptr<TcpTracerInterface> StartNewTcpTrace() override {
return nullptr;
}
std::string TraceId() override { return ""; }
std::string SpanId() override { return ""; }
bool IsSampled() override { return false; }
private:
std::vector<std::string>* annotation_logger_;
};
} // namespace grpc_core
#endif // GRPC_TEST_CORE_UTIL_FAKE_STATS_PLUGIN_H

test/core/xds/xds_cluster_resource_type_test.cc

@@ -20,6 +20,7 @@
#include <google/protobuf/any.pb.h>
#include <google/protobuf/duration.pb.h>
#include <google/protobuf/struct.pb.h>
#include <google/protobuf/wrappers.pb.h>
#include "absl/status/status.h"
@@ -1613,6 +1614,95 @@ TEST_F(HostOverrideStatusTest, CanExplicitlySetToEmpty) {
EXPECT_EQ(resource.override_host_statuses.ToString(), "{}");
}
using TelemetryLabelTest = XdsClusterTest;
TEST_F(TelemetryLabelTest, ValidServiceLabelsConfig) {
Cluster cluster;
cluster.set_type(cluster.EDS);
cluster.mutable_eds_cluster_config()->mutable_eds_config()->mutable_self();
auto& filter_map = *cluster.mutable_metadata()->mutable_filter_metadata();
auto& label_map =
*filter_map["com.google.csm.telemetry_labels"].mutable_fields();
*label_map["service_name"].mutable_string_value() = "abc";
*label_map["service_namespace"].mutable_string_value() = "xyz";
std::string serialized_resource;
ASSERT_TRUE(cluster.SerializeToString(&serialized_resource));
auto* resource_type = XdsClusterResourceType::Get();
auto decode_result =
resource_type->Decode(decode_context_, serialized_resource);
ASSERT_TRUE(decode_result.resource.ok()) << decode_result.resource.status();
auto& resource =
static_cast<const XdsClusterResource&>(**decode_result.resource);
EXPECT_THAT(*resource.telemetry_labels,
::testing::UnorderedElementsAre(
::testing::Pair("service_name", "abc"),
::testing::Pair("service_namespace", "xyz")));
}
TEST_F(TelemetryLabelTest, MissingMetadataField) {
Cluster cluster;
cluster.set_type(cluster.EDS);
cluster.mutable_eds_cluster_config()->mutable_eds_config()->mutable_self();
std::string serialized_resource;
ASSERT_TRUE(cluster.SerializeToString(&serialized_resource));
auto* resource_type = XdsClusterResourceType::Get();
auto decode_result =
resource_type->Decode(decode_context_, serialized_resource);
ASSERT_TRUE(decode_result.resource.ok()) << decode_result.resource.status();
auto& resource =
static_cast<const XdsClusterResource&>(**decode_result.resource);
EXPECT_EQ(resource.telemetry_labels, nullptr);
}
TEST_F(TelemetryLabelTest, MissingCsmFilterMetadataField) {
Cluster cluster;
cluster.set_type(cluster.EDS);
cluster.mutable_eds_cluster_config()->mutable_eds_config()->mutable_self();
auto& filter_map = *cluster.mutable_metadata()->mutable_filter_metadata();
auto& label_map = *filter_map["some_key"].mutable_fields();
*label_map["some_value"].mutable_string_value() = "abc";
std::string serialized_resource;
ASSERT_TRUE(cluster.SerializeToString(&serialized_resource));
auto* resource_type = XdsClusterResourceType::Get();
auto decode_result =
resource_type->Decode(decode_context_, serialized_resource);
ASSERT_TRUE(decode_result.resource.ok()) << decode_result.resource.status();
auto& resource =
static_cast<const XdsClusterResource&>(**decode_result.resource);
EXPECT_EQ(resource.telemetry_labels, nullptr);
}
TEST_F(TelemetryLabelTest, IgnoreNonStringEntries) {
Cluster cluster;
cluster.set_type(cluster.EDS);
cluster.mutable_eds_cluster_config()->mutable_eds_config()->mutable_self();
auto& filter_map = *cluster.mutable_metadata()->mutable_filter_metadata();
auto& label_map =
*filter_map["com.google.csm.telemetry_labels"].mutable_fields();
label_map["bool_value"].set_bool_value(true);
label_map["number_value"].set_number_value(3.14);
*label_map["string_value"].mutable_string_value() = "abc";
label_map["null_value"].set_null_value(::google::protobuf::NULL_VALUE);
auto& list_value_values =
*label_map["list_value"].mutable_list_value()->mutable_values();
*list_value_values.Add()->mutable_string_value() = "efg";
list_value_values.Add()->set_number_value(3.14);
auto& struct_value_fields =
*label_map["struct_value"].mutable_struct_value()->mutable_fields();
struct_value_fields["bool_value"].set_bool_value(false);
std::string serialized_resource;
ASSERT_TRUE(cluster.SerializeToString(&serialized_resource));
auto* resource_type = XdsClusterResourceType::Get();
auto decode_result =
resource_type->Decode(decode_context_, serialized_resource);
ASSERT_TRUE(decode_result.resource.ok()) << decode_result.resource.status();
auto& resource =
static_cast<const XdsClusterResource&>(**decode_result.resource);
EXPECT_THAT(
*resource.telemetry_labels,
::testing::UnorderedElementsAre(::testing::Pair("string_value", "abc")));
}
} // namespace
} // namespace testing
} // namespace grpc_core

test/cpp/end2end/xds/BUILD

@@ -169,6 +169,7 @@ grpc_cc_test(
"//:gpr",
"//:grpc",
"//:grpc++",
"//test/core/util:fake_stats_plugin",
"//test/core/util:grpc_test_util",
"//test/core/util:scoped_env_var",
"//test/cpp/end2end:connection_attempt_injector",

test/cpp/end2end/xds/xds_cluster_end2end_test.cc

@@ -25,9 +25,12 @@
#include "src/core/ext/filters/client_channel/backup_poller.h"
#include "src/core/lib/address_utils/sockaddr_utils.h"
#include "src/core/lib/channel/call_tracer.h"
#include "src/core/lib/config/config_vars.h"
#include "src/core/lib/experiments/experiments.h"
#include "src/core/lib/surface/call.h"
#include "src/proto/grpc/testing/xds/v3/orca_load_report.pb.h"
#include "test/core/util/fake_stats_plugin.h"
#include "test/core/util/scoped_env_var.h"
#include "test/cpp/end2end/connection_attempt_injector.h"
#include "test/cpp/end2end/xds/xds_end2end_test_lib.h"
@@ -42,6 +45,8 @@ using ::envoy::config::core::v3::HealthStatus;
using ::envoy::type::v3::FractionalPercent;
using ClientStats = LrsServiceImpl::ClientStats;
using OptionalLabelComponent =
grpc_core::ClientCallTracer::CallAttemptTracer::OptionalLabelComponent;
constexpr char kLbDropType[] = "lb";
constexpr char kThrottleDropType[] = "throttle";
@@ -304,6 +309,39 @@ TEST_P(CdsTest, ClusterChangeAfterAdsCallFails) {
WaitForBackend(DEBUG_LOCATION, 1);
}
TEST_P(CdsTest, VerifyCsmServiceLabelsParsing) {
// Injects a fake client call tracer factory. Try to keep this at the top.
grpc_core::FakeClientCallTracerFactory fake_client_call_tracer_factory;
CreateAndStartBackends(1);
// Populates EDS resources.
EdsResourceArgs args({{"locality0", CreateEndpointsForBackends()}});
balancer_->ads_service()->SetEdsResource(BuildEdsResource(args));
// Populates service labels to CDS resources.
auto cluster = default_cluster_;
auto& filter_map = *cluster.mutable_metadata()->mutable_filter_metadata();
auto& label_map =
*filter_map["com.google.csm.telemetry_labels"].mutable_fields();
*label_map["service_name"].mutable_string_value() = "myservice";
*label_map["service_namespace"].mutable_string_value() = "mynamespace";
balancer_->ads_service()->SetCdsResource(cluster);
ChannelArguments channel_args;
channel_args.SetPointer(GRPC_ARG_INJECT_FAKE_CLIENT_CALL_TRACER_FACTORY,
&fake_client_call_tracer_factory);
ResetStub(/*failover_timeout_ms=*/0, &channel_args);
// Sends an RPC and verifies that the service labels are recorded in the fake
// client call tracer.
CheckRpcSendOk(DEBUG_LOCATION);
EXPECT_THAT(fake_client_call_tracer_factory.GetLastFakeClientCallTracer()
->GetLastCallAttemptTracer()
->GetOptionalLabels(),
::testing::ElementsAre(::testing::Pair(
OptionalLabelComponent::kXdsServiceLabels,
::testing::Pointee(::testing::ElementsAre(
::testing::Pair("service_name", "myservice"),
::testing::Pair("service_namespace", "mynamespace"))))));
balancer_->Shutdown();
}
//
// CDS deletion tests
//
@@ -1874,6 +1912,7 @@ int main(int argc, char** argv) {
// Workaround Apple CFStream bug
grpc_core::SetEnv("grpc_cfstream", "0");
#endif
grpc_core::RegisterFakeStatsPlugin();
grpc_init();
grpc::testing::ConnectionAttemptInjector::Init();
const auto result = RUN_ALL_TESTS();

test/cpp/ext/csm/metadata_exchange_test.cc

@@ -113,7 +113,8 @@ class MetadataExchangeTest
public ::testing::WithParamInterface<TestScenario> {
protected:
void Init(const absl::flat_hash_set<absl::string_view>& metric_names,
bool enable_client_side_injector = true) {
bool enable_client_side_injector = true,
const std::map<std::string, std::string>& labels_to_inject = {}) {
const char* kBootstrap =
"{\"node\": {\"id\": "
"\"projects/1234567890/networks/mesh:mesh-id/nodes/"
@@ -137,7 +138,7 @@
/*labels_injector=*/
std::make_unique<grpc::internal::ServiceMeshLabelsInjector>(
GetParam().GetTestResource().GetAttributes()),
/*test_no_meter_provider=*/false,
/*test_no_meter_provider=*/false, labels_to_inject,
/*target_selector=*/
[enable_client_side_injector](absl::string_view /*target*/) {
return enable_client_side_injector;
@@ -156,11 +157,19 @@
void VerifyServiceMeshAttributes(
const std::map<std::string,
opentelemetry::sdk::common::OwnedAttributeValue>&
attributes) {
attributes,
bool verify_client_only_attributes = true) {
EXPECT_EQ(
absl::get<std::string>(attributes.at("csm.workload_canonical_service")),
"canonical_service");
EXPECT_EQ(absl::get<std::string>(attributes.at("csm.mesh_id")), "mesh-id");
if (verify_client_only_attributes) {
EXPECT_EQ(absl::get<std::string>(attributes.at("csm.service_name")),
"unknown");
EXPECT_EQ(
absl::get<std::string>(attributes.at("csm.service_namespace_name")),
"unknown");
}
switch (GetParam().type()) {
case TestScenario::ResourceType::kGke:
EXPECT_EQ(
@@ -295,7 +304,8 @@ TEST_P(MetadataExchangeTest, ServerCallDuration) {
const auto& attributes = data[kMetricName][0].attributes.GetAttributes();
EXPECT_EQ(absl::get<std::string>(attributes.at("grpc.method")), kMethodName);
EXPECT_EQ(absl::get<std::string>(attributes.at("grpc.status")), "OK");
VerifyServiceMeshAttributes(attributes);
VerifyServiceMeshAttributes(attributes,
/*verify_client_only_attributes=*/false);
}
// Test that the server records unknown when the client does not send metadata
@ -328,6 +338,27 @@ TEST_P(MetadataExchangeTest, ClientDoesNotSendMetadata) {
"unknown");
}
TEST_P(MetadataExchangeTest, VerifyCsmServiceLabels) {
Init(/*metric_names=*/{grpc::OpenTelemetryPluginBuilder::
kClientAttemptDurationInstrumentName},
/*enable_client_side_injector=*/true,
// Injects CSM service labels to be recorded in the call.
{{"service_name", "myservice"}, {"service_namespace", "mynamespace"}});
SendRPC();
const char* kMetricName = "grpc.client.attempt.duration";
auto data = ReadCurrentMetricsData(
[&](const absl::flat_hash_map<
std::string,
std::vector<opentelemetry::sdk::metrics::PointDataAttributes>>&
data) { return !data.contains(kMetricName); });
ASSERT_EQ(data[kMetricName].size(), 1);
const auto& attributes = data[kMetricName][0].attributes.GetAttributes();
EXPECT_EQ(absl::get<std::string>(attributes.at("csm.service_name")),
"myservice");
EXPECT_EQ(absl::get<std::string>(attributes.at("csm.service_namespace_name")),
"mynamespace");
}
INSTANTIATE_TEST_SUITE_P(
MetadataExchange, MetadataExchangeTest,
::testing::Values(

@ -309,6 +309,7 @@ TEST_F(OpenTelemetryPluginEnd2EndTest, TargetSelectorReturnsTrue) {
opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/
[](absl::string_view /*target*/) { return true; });
SendRPC();
@ -345,6 +346,7 @@ TEST_F(OpenTelemetryPluginEnd2EndTest, TargetSelectorReturnsFalse) {
opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/
[](absl::string_view /*target*/) { return false; });
SendRPC();
@ -364,6 +366,7 @@ TEST_F(OpenTelemetryPluginEnd2EndTest, TargetAttributeFilterReturnsTrue) {
opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/[](absl::string_view /*target*/) {
return true;
@ -402,6 +405,7 @@ TEST_F(OpenTelemetryPluginEnd2EndTest, TargetAttributeFilterReturnsFalse) {
opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/
[server_address = canonical_server_address_](
@ -469,6 +473,7 @@ TEST_F(OpenTelemetryPluginEnd2EndTest,
/*resource=*/opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/
absl::AnyInvocable<bool(absl::string_view) const>(),
@ -508,6 +513,7 @@ TEST_F(OpenTelemetryPluginEnd2EndTest,
/*resource=*/opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/
absl::AnyInvocable<bool(absl::string_view) const>(),
@ -575,6 +581,7 @@ TEST_F(OpenTelemetryPluginEnd2EndTest,
/*resource=*/opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/
absl::AnyInvocable<bool(absl::string_view) const>(),
@ -613,6 +620,7 @@ TEST_F(OpenTelemetryPluginEnd2EndTest,
/*resource=*/opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/
absl::AnyInvocable<bool(absl::string_view) const>(),
@ -685,6 +693,22 @@ class CustomLabelInjector : public grpc::internal::LabelsInjector {
grpc::internal::LabelsIterable* /*labels_from_incoming_metadata*/)
const override {}
bool AddOptionalLabels(
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>
/*optional_labels_span*/,
opentelemetry::nostd::function_ref<
bool(opentelemetry::nostd::string_view,
opentelemetry::common::AttributeValue)>
/*callback*/) const override {
return true;
}
size_t GetOptionalLabelsSize(
absl::Span<const std::shared_ptr<std::map<std::string, std::string>>>
/*optional_labels_span*/) const override {
return 0;
}
private:
std::pair<std::string, std::string> label_;
};
@ -731,6 +755,7 @@ TEST_F(OpenTelemetryPluginOptionEnd2EndTest, Basic) {
/*resource=*/opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/
absl::AnyInvocable<bool(absl::string_view) const>(),
@ -772,6 +797,7 @@ TEST_F(OpenTelemetryPluginOptionEnd2EndTest, ClientOnlyPluginOption) {
/*resource=*/opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/
absl::AnyInvocable<bool(absl::string_view) const>(),
@ -814,6 +840,7 @@ TEST_F(OpenTelemetryPluginOptionEnd2EndTest, ServerOnlyPluginOption) {
/*resource=*/opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/
absl::AnyInvocable<bool(absl::string_view) const>(),
@ -870,6 +897,7 @@ TEST_F(OpenTelemetryPluginOptionEnd2EndTest,
/*resource=*/opentelemetry::sdk::resource::Resource::Create({}),
/*labels_injector=*/nullptr,
/*test_no_meter_provider=*/false,
/*labels_to_inject=*/{},
/*target_selector=*/absl::AnyInvocable<bool(absl::string_view) const>(),
/*target_attribute_filter=*/
absl::AnyInvocable<bool(absl::string_view) const>(),

@ -29,6 +29,7 @@
#include <grpcpp/grpcpp.h>
#include "src/core/lib/channel/call_tracer.h"
#include "src/core/lib/channel/promise_based_filter.h"
#include "src/core/lib/config/core_configuration.h"
#include "src/core/lib/gprpp/notification.h"
#include "test/core/util/test_config.h"
@ -38,11 +39,55 @@
namespace grpc {
namespace testing {
#define GRPC_ARG_LABELS_TO_INJECT "grpc.testing.labels_to_inject"
// A subchannel filter that adds the test's service labels to the
// CallAttemptTracer for the call.
class AddServiceLabelsFilter : public grpc_core::ChannelFilter {
public:
static const grpc_channel_filter kFilter;
static absl::StatusOr<AddServiceLabelsFilter> Create(
const grpc_core::ChannelArgs& args, ChannelFilter::Args /*filter_args*/) {
return AddServiceLabelsFilter(
args.GetPointer<const std::map<std::string, std::string>>(
GRPC_ARG_LABELS_TO_INJECT));
}
grpc_core::ArenaPromise<grpc_core::ServerMetadataHandle> MakeCallPromise(
grpc_core::CallArgs call_args,
grpc_core::NextPromiseFactory next_promise_factory) override {
using CallAttemptTracer = grpc_core::ClientCallTracer::CallAttemptTracer;
auto* call_context = grpc_core::GetContext<grpc_call_context_element>();
auto* call_tracer = static_cast<CallAttemptTracer*>(
call_context[GRPC_CONTEXT_CALL_TRACER].value);
EXPECT_NE(call_tracer, nullptr);
call_tracer->AddOptionalLabels(
CallAttemptTracer::OptionalLabelComponent::kXdsServiceLabels,
std::make_shared<std::map<std::string, std::string>>(
*labels_to_inject_));
return next_promise_factory(std::move(call_args));
}
private:
explicit AddServiceLabelsFilter(
const std::map<std::string, std::string>* labels_to_inject)
: labels_to_inject_(labels_to_inject) {}
const std::map<std::string, std::string>* labels_to_inject_;
};
const grpc_channel_filter AddServiceLabelsFilter::kFilter =
grpc_core::MakePromiseBasedFilter<AddServiceLabelsFilter,
grpc_core::FilterEndpoint::kClient>(
"add_service_labels_filter");
void OpenTelemetryPluginEnd2EndTest::Init(
const absl::flat_hash_set<absl::string_view>& metric_names,
opentelemetry::sdk::resource::Resource resource,
std::unique_ptr<grpc::internal::LabelsInjector> labels_injector,
bool test_no_meter_provider,
const std::map<std::string, std::string>& labels_to_inject,
absl::AnyInvocable<bool(absl::string_view /*target*/) const>
target_selector,
absl::AnyInvocable<bool(absl::string_view /*target*/) const>
@ -83,6 +128,16 @@ void OpenTelemetryPluginEnd2EndTest::Init(
ot_builder.AddPluginOption(std::move(option));
}
ot_builder.BuildAndRegisterGlobal();
ChannelArguments channel_args;
if (!labels_to_inject.empty()) {
labels_to_inject_ = labels_to_inject;
grpc_core::CoreConfiguration::RegisterBuilder(
[](grpc_core::CoreConfiguration::Builder* builder) mutable {
builder->channel_init()->RegisterFilter(
GRPC_CLIENT_SUBCHANNEL, &AddServiceLabelsFilter::kFilter);
});
channel_args.SetPointer(GRPC_ARG_LABELS_TO_INJECT, &labels_to_inject_);
}
grpc_init();
grpc::ServerBuilder builder;
int port;
@ -96,8 +151,8 @@ void OpenTelemetryPluginEnd2EndTest::Init(
server_address_ = absl::StrCat("localhost:", port);
canonical_server_address_ = absl::StrCat("dns:///", server_address_);
auto channel =
grpc::CreateChannel(server_address_, grpc::InsecureChannelCredentials());
auto channel = grpc::CreateCustomChannel(
server_address_, grpc::InsecureChannelCredentials(), channel_args);
stub_ = EchoTestService::NewStub(channel);
generic_stub_ = std::make_unique<GenericStub>(std::move(channel));
}

@ -63,6 +63,7 @@ class OpenTelemetryPluginEnd2EndTest : public ::testing::Test {
opentelemetry::sdk::resource::Resource::Create({}),
std::unique_ptr<grpc::internal::LabelsInjector> labels_injector = nullptr,
bool test_no_meter_provider = false,
const std::map<std::string, std::string>& labels_to_inject = {},
absl::AnyInvocable<bool(absl::string_view /*target*/) const>
target_selector = absl::AnyInvocable<bool(absl::string_view) const>(),
absl::AnyInvocable<bool(absl::string_view /*target*/) const>
@ -94,6 +95,7 @@ class OpenTelemetryPluginEnd2EndTest : public ::testing::Test {
const absl::string_view kMethodName = "grpc.testing.EchoTestService/Echo";
const absl::string_view kGenericMethodName = "foo/bar";
std::map<std::string, std::string> labels_to_inject_;
std::shared_ptr<opentelemetry::sdk::metrics::MetricReader> reader_;
std::string server_address_;
std::string canonical_server_address_;

@ -29,14 +29,11 @@ DIRS=(
'src/python/grpcio_testing/grpc_testing'
'src/python/grpcio_status/grpc_status'
'src/python/grpcio_observability/grpc_observability'
'tools/run_tests/xds_k8s_test_driver/bin'
'tools/run_tests/xds_k8s_test_driver/framework'
)
TEST_DIRS=(
'src/python/grpcio_tests/tests'
'src/python/grpcio_tests/tests_gevent'
'tools/run_tests/xds_k8s_test_driver/tests'
)
VIRTUALENV=python_pylint_venv

@ -79,4 +79,4 @@ fi
# commit so that changes are passed to Docker
git -c user.name='foo' -c user.email='foo@google.com' commit -a -m 'Update submodule' --allow-empty
tools/run_tests/run_tests_matrix.py -f linux --exclude c sanity basictests_arm64 --inner_jobs 16 -j 2 --internal_ci --build_only
tools/run_tests/run_tests_matrix.py -f linux --exclude c sanity basictests_arm64 openssl --inner_jobs 16 -j 2 --internal_ci --build_only

@ -284,7 +284,7 @@ test_driver_compile_protos() {
#######################################
# Installs the test driver and its requirements.
# https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#installation
# https://github.com/grpc/psm-interop#installation
# Globals:
# TEST_DRIVER_REPO_DIR: Populated with the path to the repo containing
# the test driver

@ -99,7 +99,7 @@ build_docker_images_if_needed() {
#######################################
run_test() {
# Test driver usage:
# https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage
# https://github.com/grpc/psm-interop#basic-usage
local test_name="${1:?Usage: run_test test_name}"
local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}"
mkdir -pv "${out_dir}"

@ -112,7 +112,7 @@ build_docker_images_if_needed() {
#######################################
run_test() {
# Test driver usage:
# https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage
# https://github.com/grpc/psm-interop#basic-usage
local test_name="${1:?Usage: run_test test_name}"
local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}"
mkdir -pv "${out_dir}"

@ -69,7 +69,7 @@ run_test() {
exit 1
fi
# Test driver usage:
# https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage
# https://github.com/grpc/psm-interop#basic-usage
local client_lang="$1"
local client_branch="$2"
local server_lang="$3"

@ -87,7 +87,7 @@ build_docker_images_if_needed() {
#######################################
run_test() {
# Test driver usage:
# https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage
# https://github.com/grpc/psm-interop#basic-usage
local test_name="${1:?Usage: run_test test_name}"
local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}"
mkdir -pv "${out_dir}"

@ -97,7 +97,7 @@ build_docker_images_if_needed() {
#######################################
run_test() {
# Test driver usage:
# https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage
# https://github.com/grpc/psm-interop#basic-usage
local test_name="${1:?Usage: run_test test_name}"
local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}"
mkdir -pv "${out_dir}"

@ -99,7 +99,7 @@ build_docker_images_if_needed() {
#######################################
run_test() {
# Test driver usage:
# https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage
# https://github.com/grpc/psm-interop#basic-usage
local test_name="${1:?Usage: run_test test_name}"
local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}"
mkdir -pv "${out_dir}"

@ -114,7 +114,7 @@ build_docker_images_if_needed() {
#######################################
run_test() {
# Test driver usage:
# https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage
# https://github.com/grpc/psm-interop#basic-usage
local test_name="${1:?Usage: run_test test_name}"
local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}"
mkdir -pv "${out_dir}"

@ -99,7 +99,7 @@ build_docker_images_if_needed() {
#######################################
run_test() {
# Test driver usage:
# https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage
# https://github.com/grpc/psm-interop#basic-usage
local test_name="${1:?Usage: run_test test_name}"
local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}"
mkdir -pv "${out_dir}"

@ -370,7 +370,8 @@ def _create_portability_test_jobs(
platforms=["linux"],
arch="x64",
compiler=compiler,
labels=["portability", "corelang"],
labels=["portability", "corelang"]
+ (["openssl"] if "openssl" in compiler else []),
extra_args=extra_args,
inner_jobs=inner_jobs,
timeout_seconds=_CPP_RUNTESTS_TIMEOUT,

@ -1,5 +0,0 @@
config/local-*.cfg
src/proto
venv/
venv-*/
out/

@ -1,458 +1,3 @@
# xDS Kubernetes Interop Tests
Proxyless Security Mesh Interop Tests executed on Kubernetes.
### Experimental
Work in progress. Internal APIs may and will change. Please refrain from making
changes to this codebase at the moment.
### Stabilization roadmap
- [x] Replace retrying with tenacity
- [x] Generate namespace for each test to prevent resource name conflicts and
allow running tests in parallel
- [x] Security: run server and client in separate namespaces
- [ ] Make framework.infrastructure.gcp resources [first-class
citizens](https://en.wikipedia.org/wiki/First-class_citizen), support
simpler CRUD
- [x] Security: manage `roles/iam.workloadIdentityUser` role grant lifecycle for
dynamically-named namespaces
- [x] Restructure `framework.test_app` and `framework.xds_k8s*` into a module
containing xDS-interop-specific logic
- [ ] Address inline TODOs in code
- [x] Improve README.md documentation, explain helpers in bin/ folder
## Installation
#### Requirements
1. Python v3.9+
2. [Google Cloud SDK](https://cloud.google.com/sdk/docs/install)
3. `kubectl`
`kubectl` can be installed via `gcloud components install kubectl`, or via a system package manager: https://kubernetes.io/docs/tasks/tools/#kubectl
The Python3 venv tool may need to be installed from APT on some Ubuntu systems:
```shell
sudo apt-get install python3-venv
```
##### Getting Started
1. If you haven't already, [initialize](https://cloud.google.com/sdk/docs/install-sdk) the gcloud SDK
2. Activate a gcloud [configuration](https://cloud.google.com/sdk/docs/configurations) with your project
3. Enable gcloud services:
```shell
gcloud services enable \
compute.googleapis.com \
container.googleapis.com \
logging.googleapis.com \
monitoring.googleapis.com \
networksecurity.googleapis.com \
networkservices.googleapis.com \
secretmanager.googleapis.com \
trafficdirector.googleapis.com
```
#### Configure GKE cluster
This is an example outlining the minimal requirements to run the [baseline tests](#xds-baseline-tests).
Update the gcloud SDK:
```shell
gcloud -q components update
```
Pre-populate environment variables for convenience. To find your project ID, refer to
[Identifying projects](https://cloud.google.com/resource-manager/docs/creating-managing-projects#identifying_projects).
```shell
export PROJECT_ID="your-project-id"
export PROJECT_NUMBER=$(gcloud projects describe "${PROJECT_ID}" --format="value(projectNumber)")
# Compute Engine default service account
export GCE_SA="${PROJECT_NUMBER}-compute@developer.gserviceaccount.com"
# The prefix to name GCP resources used by the framework
export RESOURCE_PREFIX="xds-k8s-interop-tests"
# The name of your cluster, e.g. xds-k8s-test-cluster
export CLUSTER_NAME="${RESOURCE_PREFIX}-cluster"
# The zone of your cluster, e.g. us-central1-a
export ZONE="us-central1-a"
# Dedicated GCP Service Account to use with workload identity.
export WORKLOAD_SA_NAME="${RESOURCE_PREFIX}"
export WORKLOAD_SA_EMAIL="${WORKLOAD_SA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
```
##### Create the cluster
Minimal requirements: [VPC-native](https://cloud.google.com/traffic-director/docs/security-proxyless-setup)
cluster with [Workload Identity](https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity) enabled.
```shell
gcloud container clusters create "${CLUSTER_NAME}" \
--scopes=cloud-platform \
--zone="${ZONE}" \
--enable-ip-alias \
--workload-pool="${PROJECT_ID}.svc.id.goog" \
--workload-metadata=GKE_METADATA \
--tags=allow-health-checks
```
For security tests, you also need to create CAs and configure the cluster to use those CAs
as described
[here](https://cloud.google.com/traffic-director/docs/security-proxyless-setup#configure-cas).
##### Create the firewall rule
Allow [health checking mechanisms](https://cloud.google.com/traffic-director/docs/set-up-proxyless-gke#creating_the_health_check_firewall_rule_and_backend_service)
to query the workloads' health.
This step can be skipped if the driver is executed with `--ensure_firewall`.
```shell
gcloud compute firewall-rules create "${RESOURCE_PREFIX}-allow-health-checks" \
--network=default --action=allow --direction=INGRESS \
--source-ranges="35.191.0.0/16,130.211.0.0/22" \
--target-tags=allow-health-checks \
--rules=tcp:8080-8100
```
##### Setup GCP Service Account
Create a dedicated GCP Service Account to use
with [workload identity](https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity).
```shell
gcloud iam service-accounts create "${WORKLOAD_SA_NAME}" \
--display-name="xDS K8S Interop Tests Workload Identity Service Account"
```
Enable the service account to [access the Traffic Director API](https://cloud.google.com/traffic-director/docs/prepare-for-envoy-setup#enable-service-account).
```shell
gcloud projects add-iam-policy-binding "${PROJECT_ID}" \
--member="serviceAccount:${WORKLOAD_SA_EMAIL}" \
--role="roles/trafficdirector.client"
```
##### Allow access to images
The test framework needs read access to the client and server images and the bootstrap
generator image. You may have these images in your own project, but if you want to use the
ones from the grpc-testing project, you will have to grant the necessary access to them
using https://cloud.google.com/container-registry/docs/access-control#grant or a
gsutil command. For example, to grant access to images stored in the `grpc-testing` project GCR, run:
```sh
gsutil iam ch "serviceAccount:${GCE_SA}:objectViewer" gs://artifacts.grpc-testing.appspot.com/
```
##### Allow test driver to configure workload identity automatically
The test driver will automatically grant `roles/iam.workloadIdentityUser` to
allow the Kubernetes service account to impersonate the dedicated GCP workload
service account (this corresponds to step 5
of [Authenticating to Google Cloud](https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity#authenticating_to)).
This action requires the test framework to have `iam.serviceAccounts.create`
permission on the project.
If you're running the test framework locally and you have `roles/owner` on your
project, **you can skip this step**.
If you're configuring the test framework to run on CI, use a `roles/owner`
account once to allow the test framework to grant `roles/iam.workloadIdentityUser`.
```shell
# Assuming CI is using Compute Engine default service account.
gcloud projects add-iam-policy-binding "${PROJECT_ID}" \
--member="serviceAccount:${GCE_SA}" \
--role="roles/iam.serviceAccountAdmin" \
--condition-from-file=<(cat <<-END
---
title: allow_workload_identity_only
description: Restrict serviceAccountAdmin to granting role iam.workloadIdentityUser
expression: |-
api.getAttribute('iam.googleapis.com/modifiedGrantsByRole', [])
.hasOnly(['roles/iam.workloadIdentityUser'])
END
)
```
##### Configure GKE cluster access
```shell
# Unless you're using a GCP VM with preconfigured Application Default Credentials, acquire them for your user
gcloud auth application-default login
# Install authentication plugin for kubectl.
# Details: https://cloud.google.com/blog/products/containers-kubernetes/kubectl-auth-changes-in-gke
gcloud components install gke-gcloud-auth-plugin
# Configure GKE cluster access for kubectl
gcloud container clusters get-credentials "${CLUSTER_NAME}" --zone "${ZONE}"
# Save generated kube context name
export KUBE_CONTEXT="$(kubectl config current-context)"
```
#### Install python dependencies
```shell
# Create python virtual environment
python3 -m venv venv
# Activate virtual environment
. ./venv/bin/activate
# Install requirements
pip install -r requirements.lock
# Generate protos
python -m grpc_tools.protoc --proto_path=../../../ \
--python_out=. --grpc_python_out=. \
src/proto/grpc/testing/empty.proto \
src/proto/grpc/testing/messages.proto \
src/proto/grpc/testing/test.proto
```
# Basic usage
## Local development
This test driver allows running tests locally against remote GKE clusters, right
from your dev environment. You need to:
1. Follow the [installation](#installation) instructions
2. Authenticate `gcloud`
3. Set up a `kubectl` context (see [Configure GKE cluster access](#configure-gke-cluster-access))
4. Run tests with the `--debug_use_port_forwarding` argument, as in the sketch
below. The test driver will automatically start and stop port forwarding
using `kubectl` subprocesses. (experimental)
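For example, a local run against a remote cluster might look like this (a
sketch combining the run helper and the flag shown elsewhere in this README):
```shell
# Run the baseline suite locally, port-forwarding through kubectl subprocesses.
./run.sh tests/baseline_test.py --debug_use_port_forwarding
```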
### Making changes to the driver
1. Install additional dev packages: `pip install -r requirements-dev.txt`
2. Use `./bin/black.sh` and `./bin/isort.sh` helpers to auto-format code.
### Updating Python Dependencies
We track our Python-level dependencies using three different files:
- `requirements.txt`
- `dev-requirements.txt`
- `requirements.lock`
`requirements.txt` lists modules without specific versions supplied, though
version ranges may be specified. `requirements.lock` is generated from
`requirements.txt` and _does_ specify versions for every dependency in the
transitive dependency tree.
When updating `requirements.txt`, you must also update `requirements.lock`. To
do this, navigate to this directory and run `./bin/freeze.sh`.
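For example, a typical update might go like this (a sketch; the `git diff`
step is just one way to review the result):
```shell
# After editing requirements.txt, regenerate the lock file:
./bin/freeze.sh
# Review the newly pinned versions before committing:
git diff requirements.lock
```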
### Setup test configuration
There are many arguments to be passed into the test run. You can save the
arguments to a config file ("flagfile") for your development environment.
Use [`config/local-dev.cfg.example`](https://github.com/grpc/grpc/blob/master/tools/run_tests/xds_k8s_test_driver/config/local-dev.cfg.example)
as a starting point:
```shell
cp config/local-dev.cfg.example config/local-dev.cfg
```
If you exported environment variables in the above sections, you can
template them into the local config (note this recreates the config):
```shell
envsubst < config/local-dev.cfg.example > config/local-dev.cfg
```
Learn more about flagfiles in [abseil documentation](https://abseil.io/docs/python/guides/flags#a-note-about---flagfile).
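For illustration, a minimal flagfile might look like the sketch below. The
flag names appear elsewhere in this driver; the values are placeholders, so
substitute your own:
```shell
# Hypothetical config/local-dev.cfg contents.
--project=my-gcp-project
--network=default
--kube_context=gke_my-gcp-project_us-central1-a_my-cluster
--resource_prefix=xds-k8s-interop-tests
```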
## Test suites
See the full list of available test suites in the [`tests/`](https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver/tests) folder.
### xDS Baseline Tests
Test suite meant to confirm that basic xDS features work as expected. Executing
it before other test suites will help to identify whether a test failure is
related to the specific features under test or caused by unrelated
infrastructure disturbances.
```shell
# Help
python -m tests.baseline_test --help
python -m tests.baseline_test --helpfull
# Run the baseline test with local-dev.cfg settings
python -m tests.baseline_test --flagfile="config/local-dev.cfg"
# Same as above, but using the helper script
./run.sh tests/baseline_test.py
```
### xDS Security Tests
Test suite meant to verify mTLS/TLS features. Note that this requires
additional environment configuration. For more details and the setup
required for the security tests, see the
["Setting up Traffic Director service security with proxyless gRPC"](https://cloud.google.com/traffic-director/docs/security-proxyless-setup) user guide.
```shell
# Run the security test with local-dev.cfg settings
python -m tests.security_test --flagfile="config/local-dev.cfg"
# Same as above, but using the helper script
./run.sh tests/security_test.py
```
## Helper scripts
You can use the interop xds-k8s [`bin/`](https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver/bin)
scripts to configure TD, start k8s instances step-by-step, and keep them alive
for as long as you need.
* To run helper scripts using local config:
* `python -m bin.script_name --flagfile=config/local-dev.cfg`
* `./run.sh bin/script_name.py` automatically appends the flagfile
* Use `--help` to see script-specific arguments
* Use `--helpfull` to see all available arguments
#### Overview
```shell
# Helper tool to configure Traffic Director with different security options
python -m bin.run_td_setup --help
# Helper tools to run the test server and client (with or without security)
python -m bin.run_test_server --help
python -m bin.run_test_client --help
# Helper tool to verify different security configurations via channelz
python -m bin.run_channelz --help
```
#### `./run.sh` helper
Use `./run.sh` to execute helper scripts and tests with `config/local-dev.cfg`.
```sh
USAGE: ./run.sh script_path [arguments]
script_path: path to python script to execute, relative to driver root folder
arguments ...: arguments passed to program in sys.argv
ENVIRONMENT:
XDS_K8S_CONFIG: file path to the config flagfile, relative to
driver root folder. Default: config/local-dev.cfg
Will be appended as --flagfile="config_absolute_path" argument
XDS_K8S_DRIVER_VENV_DIR: the path to python virtual environment directory
Default: $XDS_K8S_DRIVER_DIR/venv
DESCRIPTION:
This tool performs the following:
1) Ensures python virtual env installed and activated
2) Exports test driver root in PYTHONPATH
3) Automatically appends --flagfile="\$XDS_K8S_CONFIG" argument
EXAMPLES:
./run.sh bin/run_td_setup.py --help
./run.sh bin/run_td_setup.py --helpfull
XDS_K8S_CONFIG=./path-to-flagfile.cfg ./run.sh bin/run_td_setup.py --resource_suffix=override-suffix
./run.sh tests/baseline_test.py
./run.sh tests/security_test.py --verbosity=1 --logger_levels=__main__:DEBUG,framework:DEBUG
./run.sh tests/security_test.py SecurityTest.test_mtls --nocheck_local_certs
```
## Partial setups
### Regular workflow
```shell
# Setup Traffic Director
./run.sh bin/run_td_setup.py
# Start test server
./run.sh bin/run_test_server.py
# Add test server to the backend service
./run.sh bin/run_td_setup.py --cmd=backends-add
# Start test client
./run.sh bin/run_test_client.py
```
### Secure workflow
```shell
# Set up Traffic Director in mTLS mode. See --help for all options
./run.sh bin/run_td_setup.py --security=mtls
# Start test server in a secure mode
./run.sh bin/run_test_server.py --mode=secure
# Add test server to the backend service
./run.sh bin/run_td_setup.py --cmd=backends-add
# Start test client in a secure mode
./run.sh bin/run_test_client.py --mode=secure
```
### Sending RPCs
#### Start port forwarding
```shell
# Client: all services always on port 8079
kubectl port-forward deployment.apps/psm-grpc-client 8079
# Server regular mode: all grpc services on port 8080
kubectl port-forward deployment.apps/psm-grpc-server 8080
# OR
# Server secure mode: TestServiceImpl is on 8080,
kubectl port-forward deployment.apps/psm-grpc-server 8080
# everything else (channelz, healthcheck, CSDS) on 8081
kubectl port-forward deployment.apps/psm-grpc-server 8081
```
#### Send RPCs with grpcurl
```shell
# 8081 if security enabled
export SERVER_ADMIN_PORT=8080
# List server services using reflection
grpcurl --plaintext 127.0.0.1:$SERVER_ADMIN_PORT list
# List client services using reflection
grpcurl --plaintext 127.0.0.1:8079 list
# List channels via channelz
grpcurl --plaintext 127.0.0.1:$SERVER_ADMIN_PORT grpc.channelz.v1.Channelz.GetTopChannels
grpcurl --plaintext 127.0.0.1:8079 grpc.channelz.v1.Channelz.GetTopChannels
# Send GetClientStats to the client
grpcurl --plaintext -d '{"num_rpcs": 10, "timeout_sec": 30}' 127.0.0.1:8079 \
grpc.testing.LoadBalancerStatsService.GetClientStats
```
### Cleanup
* First, make sure to stop port forwarding, if any
* Run `./bin/cleanup.sh`
##### Partial cleanup
You can run the commands below to stop/start or create/delete resources however you want.
Generally, it's better to remove resources in the opposite order of their creation.
Cleanup regular resources:
```shell
# Cleanup TD resources
./run.sh bin/run_td_setup.py --cmd=cleanup
# Stop test client
./run.sh bin/run_test_client.py --cmd=cleanup
# Stop test server, and remove the namespace
./run.sh bin/run_test_server.py --cmd=cleanup --cleanup_namespace
```
Cleanup regular and security-specific resources:
```shell
# Cleanup TD resources, with security
./run.sh bin/run_td_setup.py --cmd=cleanup --security=mtls
# Stop test client (secure)
./run.sh bin/run_test_client.py --cmd=cleanup --mode=secure
# Stop test server (secure), and remove the namespace
./run.sh bin/run_test_server.py --cmd=cleanup --cleanup_namespace --mode=secure
```
In addition, here are some other helpful partial cleanup commands:
```shell
# Remove all backends from the backend services
./run.sh bin/run_td_setup.py --cmd=backends-cleanup
# Stop the server, but keep the namespace
./run.sh bin/run_test_server.py --cmd=cleanup --nocleanup_namespace
```
### Known errors
#### Error forwarding port
If you stopped a test with `ctrl+c` while using `--debug_use_port_forwarding`,
you might see an error like this:
> `framework.infrastructure.k8s.PortForwardingError: Error forwarding port, unexpected output Unable to listen on port 8081: Listeners failed to create with the following errors: [unable to create listener: Error listen tcp4 127.0.0.1:8081: bind: address already in use]`
Unless you're running `kubectl port-forward` manually, it's likely that `ctrl+c`
interrupted Python before it could clean up its subprocesses.
You can run `ps aux | grep port-forward` and then kill the processes by ID,
or use `killall kubectl`.
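For example:
```shell
# Find port-forward subprocesses left behind by the interrupted run.
ps aux | grep port-forward
# Kill them individually by id (kill <pid>), or all kubectl processes at once:
killall kubectl
```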
The source has been migrated to https://github.com/grpc/psm-interop.

@ -1,13 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,60 +0,0 @@
#!/usr/bin/env bash
# Copyright 2021 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eo pipefail
display_usage() {
cat <<EOF >/dev/stderr
A helper to run the black formatter.
USAGE: $0 [--diff|--check]
--diff: Do not apply changes, only show the diff
--check: Do not apply changes, only print what files will be changed
ENVIRONMENT:
XDS_K8S_DRIVER_VENV_DIR: the path to python virtual environment directory
Default: $XDS_K8S_DRIVER_DIR/venv
EXAMPLES:
$0
$0 --diff
$0 --check
EOF
exit 1
}
if [[ "$1" == "-h" || "$1" == "--help" ]]; then
display_usage
fi
SCRIPT_DIR="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
readonly SCRIPT_DIR
readonly XDS_K8S_DRIVER_DIR="${SCRIPT_DIR}/.."
cd "${XDS_K8S_DRIVER_DIR}"
# Relative paths not yet supported by shellcheck.
# shellcheck source=/dev/null
source "${XDS_K8S_DRIVER_DIR}/bin/ensure_venv.sh"
if [[ "$1" == "--diff" ]]; then
readonly MODE="--diff"
elif [[ "$1" == "--check" ]]; then
readonly MODE="--check"
else
readonly MODE=""
fi
# shellcheck disable=SC2086
exec python -m black --config=../../../black.toml ${MODE} .

@ -1,59 +0,0 @@
#!/usr/bin/env bash
# Copyright 2021 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eo pipefail
display_usage() {
cat <<EOF >/dev/stderr
Performs full TD and K8S resource cleanup
USAGE: $0 [--nosecure] [arguments]
--nosecure: Skip cleanup for the resources specific for PSM Security
arguments ...: additional arguments passed to ./run.sh
ENVIRONMENT:
XDS_K8S_CONFIG: file path to the config flagfile, relative to
driver root folder. Default: config/local-dev.cfg
Will be appended as --flagfile="config_absolute_path" argument
XDS_K8S_DRIVER_VENV_DIR: the path to python virtual environment directory
Default: $XDS_K8S_DRIVER_DIR/venv
EXAMPLES:
$0
$0 --nosecure
XDS_K8S_CONFIG=./path-to-flagfile.cfg $0 --resource_suffix=override-suffix
EOF
exit 1
}
if [[ "$1" == "-h" || "$1" == "--help" ]]; then
display_usage
fi
SCRIPT_DIR="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
readonly SCRIPT_DIR
readonly XDS_K8S_DRIVER_DIR="${SCRIPT_DIR}/.."
cd "${XDS_K8S_DRIVER_DIR}"
if [[ "$1" == "--nosecure" ]]; then
shift
./run.sh bin/run_td_setup.py --cmd=cleanup "$@" && \
./run.sh bin/run_test_client.py --cmd=cleanup --cleanup_namespace "$@" && \
./run.sh bin/run_test_server.py --cmd=cleanup --cleanup_namespace "$@"
else
./run.sh bin/run_td_setup.py --cmd=cleanup --security=mtls "$@" && \
./run.sh bin/run_test_client.py --cmd=cleanup --cleanup_namespace --mode=secure "$@" && \
./run.sh bin/run_test_server.py --cmd=cleanup --cleanup_namespace --mode=secure "$@"
fi

@ -1,2 +0,0 @@
# This folder contains scripts to delete leaked resources from test runs

@ -1,714 +0,0 @@
# Copyright 2021 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Clean up resources created by the tests.
This is intended as a tool to delete leaked resources from old tests.
Typical usage examples:
python3 -m bin.cleanup.cleanup \
--project=grpc-testing \
--network=default-vpc \
--kube_context=gke_grpc-testing_us-central1-a_psm-interop-security
"""
import dataclasses
import datetime
import functools
import json
import logging
import os
import re
import subprocess
import sys
from typing import Any, Callable, List, Optional
from absl import app
from absl import flags
import dateutil
from framework import xds_flags
from framework import xds_k8s_flags
from framework.helpers import retryers
from framework.infrastructure import gcp
from framework.infrastructure import k8s
from framework.infrastructure import traffic_director
from framework.test_app.runners.k8s import k8s_xds_client_runner
from framework.test_app.runners.k8s import k8s_xds_server_runner
logger = logging.getLogger(__name__)
Json = Any
_KubernetesClientRunner = k8s_xds_client_runner.KubernetesClientRunner
_KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner
GCLOUD = os.environ.get("GCLOUD", "gcloud")
GCLOUD_CMD_TIMEOUT_S = datetime.timedelta(seconds=5).total_seconds()
# Skip known k8s system namespaces.
K8S_PROTECTED_NAMESPACES = {
"default",
"gke-managed-system",
"kube-node-lease",
"kube-public",
"kube-system",
}
# TODO(sergiitk): these should be flags.
LEGACY_DRIVER_ZONE = "us-central1-a"
LEGACY_DRIVER_SECONDARY_ZONE = "us-west1-b"
PSM_INTEROP_PREFIX = "psm-interop" # Prefix for gke resources to delete.
URL_MAP_TEST_PREFIX = (
"interop-psm-url-map" # Prefix for url-map test resources to delete.
)
KEEP_PERIOD_HOURS = flags.DEFINE_integer(
"keep_hours",
default=48,
help=(
"number of hours for a resource to keep. Resources older than this will"
" be deleted. Default is 48 hours (2 days)"
),
)
DRY_RUN = flags.DEFINE_bool(
"dry_run",
default=False,
help="dry run, print resources but do not perform deletion",
)
TD_RESOURCE_PREFIXES = flags.DEFINE_list(
"td_resource_prefixes",
default=[PSM_INTEROP_PREFIX],
help=(
"a comma-separated list of prefixes for which the leaked TD resources"
" will be deleted"
),
)
SERVER_PREFIXES = flags.DEFINE_list(
"server_prefixes",
default=[PSM_INTEROP_PREFIX],
help=(
"a comma-separated list of prefixes for which the leaked servers will"
" be deleted"
),
)
CLIENT_PREFIXES = flags.DEFINE_list(
"client_prefixes",
default=[PSM_INTEROP_PREFIX, URL_MAP_TEST_PREFIX],
help=(
"a comma-separated list of prefixes for which the leaked clients will"
" be deleted"
),
)
MODE = flags.DEFINE_enum(
"mode",
default="td",
enum_values=["k8s", "td", "td_no_legacy"],
help="Mode: Kubernetes or Traffic Director",
)
SECONDARY = flags.DEFINE_bool(
"secondary",
default=False,
help="Cleanup secondary (alternative) resources",
)
# The cleanup script performs some API calls directly, so some flags that are
# normally required to configure the framework properly are not needed here.
flags.FLAGS.set_default("resource_prefix", "ignored-by-cleanup")
flags.FLAGS.set_default("td_bootstrap_image", "ignored-by-cleanup")
flags.FLAGS.set_default("server_image", "ignored-by-cleanup")
flags.FLAGS.set_default("client_image", "ignored-by-cleanup")
@dataclasses.dataclass(eq=False)
class CleanupResult:
error_count: int = 0
error_messages: List[str] = dataclasses.field(default_factory=list)
def add_error(self, msg: str):
self.error_count += 1
self.error_messages.append(f" {self.error_count}. {msg}")
def format_messages(self):
return "\n".join(self.error_messages)
@dataclasses.dataclass(frozen=True)
class K8sResourceRule:
# regex to match
expression: str
# function to delete the resource
cleanup_ns_fn: Callable
# Global state, holding the result of the whole operation.
_CLEANUP_RESULT = CleanupResult()
def load_keep_config() -> None:
global KEEP_CONFIG
json_path = os.path.realpath(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"keep_xds_interop_resources.json",
)
)
with open(json_path, "r") as f:
KEEP_CONFIG = json.load(f)
logging.debug(
"Resource keep config loaded: %s", json.dumps(KEEP_CONFIG, indent=2)
)
def is_marked_as_keep_gce(suffix: str) -> bool:
return suffix in KEEP_CONFIG["gce_framework"]["suffix"]
def is_marked_as_keep_gke(suffix: str) -> bool:
return suffix in KEEP_CONFIG["gke_framework"]["suffix"]
@functools.lru_cache()
def get_expire_timestamp() -> datetime.datetime:
return datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
hours=KEEP_PERIOD_HOURS.value
)
def exec_gcloud(project: str, *cmds: str) -> Json:
cmds = [GCLOUD, "--project", project, "--quiet"] + list(cmds)
if "list" in cmds:
# Add arguments to shape the list output
cmds.extend(
[
"--format",
"json",
"--filter",
f"creationTimestamp <= {get_expire_timestamp().isoformat()}",
]
)
# Executing the gcloud command
logging.debug("Executing: %s", " ".join(cmds))
proc = subprocess.Popen(
cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
# NOTE(lidiz) the gcloud subprocess won't return unless its output is read
stdout = proc.stdout.read()
stderr = proc.stderr.read()
try:
returncode = proc.wait(timeout=GCLOUD_CMD_TIMEOUT_S)
except subprocess.TimeoutExpired:
logging.error("> Timeout executing cmd [%s]", " ".join(cmds))
return None
if returncode:
logging.error(
"> Failed to execute cmd [%s], returned %d, stderr: %s",
" ".join(cmds),
returncode,
stderr,
)
return None
if stdout:
return json.loads(stdout)
return None
def cleanup_legacy_driver_resources(*, project: str, suffix: str, **kwargs):
"""Removing GCP resources created by run_xds_tests.py."""
# Unused, but kept for compatibility with cleanup_td_for_gke.
del kwargs
logging.info(
"----- Removing run_xds_tests.py resources with suffix [%s]", suffix
)
exec_gcloud(
project,
"compute",
"forwarding-rules",
"delete",
f"test-forwarding-rule{suffix}",
"--global",
)
exec_gcloud(
project,
"compute",
"target-http-proxies",
"delete",
f"test-target-proxy{suffix}",
)
exec_gcloud(
project,
"alpha",
"compute",
"target-grpc-proxies",
"delete",
f"test-target-proxy{suffix}",
)
exec_gcloud(project, "compute", "url-maps", "delete", f"test-map{suffix}")
exec_gcloud(
project,
"compute",
"backend-services",
"delete",
f"test-backend-service{suffix}",
"--global",
)
exec_gcloud(
project,
"compute",
"backend-services",
"delete",
f"test-backend-service-alternate{suffix}",
"--global",
)
exec_gcloud(
project,
"compute",
"backend-services",
"delete",
f"test-backend-service-extra{suffix}",
"--global",
)
exec_gcloud(
project,
"compute",
"backend-services",
"delete",
f"test-backend-service-more-extra{suffix}",
"--global",
)
exec_gcloud(
project, "compute", "firewall-rules", "delete", f"test-fw-rule{suffix}"
)
exec_gcloud(
project, "compute", "health-checks", "delete", f"test-hc{suffix}"
)
exec_gcloud(
project,
"compute",
"instance-groups",
"managed",
"delete",
f"test-ig{suffix}",
"--zone",
LEGACY_DRIVER_ZONE,
)
exec_gcloud(
project,
"compute",
"instance-groups",
"managed",
"delete",
f"test-ig-same-zone{suffix}",
"--zone",
LEGACY_DRIVER_ZONE,
)
exec_gcloud(
project,
"compute",
"instance-groups",
"managed",
"delete",
f"test-ig-secondary-zone{suffix}",
"--zone",
LEGACY_DRIVER_SECONDARY_ZONE,
)
exec_gcloud(
project,
"compute",
"instance-templates",
"delete",
f"test-template{suffix}",
)
# cleanup_td creates TrafficDirectorManager (and its variants for security and
# AppNet), and then calls their cleanup() methods.
#
# Note that the variants are all based on the basic TrafficDirectorManager, so
# their `cleanup()` might do duplicate work. But deleting a non-existent
# resource returns 404, which is OK.
def cleanup_td_for_gke(*, project, prefix, suffix, network):
gcp_api_manager = gcp.api.GcpApiManager()
plain_td = traffic_director.TrafficDirectorManager(
gcp_api_manager,
project=project,
network=network,
resource_prefix=prefix,
resource_suffix=suffix,
)
security_td = traffic_director.TrafficDirectorSecureManager(
gcp_api_manager,
project=project,
network=network,
resource_prefix=prefix,
resource_suffix=suffix,
)
# TODO: cleanup appnet resources.
# appnet_td = traffic_director.TrafficDirectorAppNetManager(
# gcp_api_manager,
# project=project,
# network=network,
# resource_prefix=resource_prefix,
# resource_suffix=resource_suffix)
logger.info(
"----- Removing traffic director for gke, prefix %s, suffix %s",
prefix,
suffix,
)
security_td.cleanup(force=True)
# appnet_td.cleanup(force=True)
plain_td.cleanup(force=True)
# cleanup_client creates a client runner, and calls its cleanup() method.
def cleanup_client(
project,
network,
k8s_api_manager,
client_namespace,
gcp_api_manager,
gcp_service_account,
*,
suffix: Optional[str] = "",
):
deployment_name = xds_flags.CLIENT_NAME.value
if suffix:
deployment_name = f"{deployment_name}-{suffix}"
ns = k8s.KubernetesNamespace(k8s_api_manager, client_namespace)
# Shorten the timeout to avoid waiting for stuck namespaces.
# Normal ns deletion during the cleanup takes less than two minutes.
ns.wait_for_namespace_deleted_timeout_sec = 5 * 60
client_runner = _KubernetesClientRunner(
k8s_namespace=ns,
deployment_name=deployment_name,
gcp_project=project,
network=network,
gcp_service_account=gcp_service_account,
gcp_api_manager=gcp_api_manager,
image_name="",
td_bootstrap_image="",
)
logger.info("Cleanup client")
try:
client_runner.cleanup(force=True, force_namespace=True)
except retryers.RetryError as err:
logger.error(
"Timeout waiting for namespace %s deletion. "
"Failed resource status:\n\n%s",
ns.name,
ns.pretty_format_status(err.result()),
)
raise
# cleanup_server creates a server runner, and calls its cleanup() method.
def cleanup_server(
project,
network,
k8s_api_manager,
server_namespace,
gcp_api_manager,
gcp_service_account,
*,
suffix: Optional[str] = "",
):
deployment_name = xds_flags.SERVER_NAME.value
if suffix:
deployment_name = f"{deployment_name}-{suffix}"
ns = k8s.KubernetesNamespace(k8s_api_manager, server_namespace)
# Shorten the timeout to avoid waiting for stuck namespaces.
# Normal ns deletion during the cleanup takes less than two minutes.
ns.wait_for_namespace_deleted_timeout_sec = 5 * 60
server_runner = _KubernetesServerRunner(
k8s_namespace=ns,
deployment_name=deployment_name,
gcp_project=project,
network=network,
gcp_service_account=gcp_service_account,
gcp_api_manager=gcp_api_manager,
image_name="",
td_bootstrap_image="",
)
logger.info("Cleanup server")
try:
server_runner.cleanup(force=True, force_namespace=True)
except retryers.RetryError as err:
logger.error(
"Timeout waiting for namespace %s deletion. "
"Failed resource status:\n\n%s",
ns.name,
ns.pretty_format_status(err.result()),
)
raise
def delete_leaked_td_resources(
dry_run, td_resource_rules, project, network, resources
):
for resource in resources:
logger.info("-----")
logger.info("----- Cleaning up resource %s", resource["name"])
if dry_run:
# Skip deletion for dry-runs
logging.info("----- Skipped [Dry Run]: %s", resource["name"])
continue
matched = False
for regex, resource_prefix, keep, remove_fn in td_resource_rules:
result = re.search(regex, resource["name"])
if result is not None:
matched = True
if keep(result.group(1)):
logging.info("Skipped [keep]:")
break # break inner loop, continue outer loop
remove_fn(
project=project,
prefix=resource_prefix,
suffix=result.group(1),
network=network,
)
break
if not matched:
logging.info(
"----- Skipped [does not matching resource name templates]"
)
def delete_k8s_resources(
dry_run,
k8s_resource_rules,
project,
network,
k8s_api_manager,
gcp_service_account,
namespaces,
):
gcp_api_manager = gcp.api.GcpApiManager()
for ns in namespaces:
namespace_name: str = ns.metadata.name
if namespace_name in K8S_PROTECTED_NAMESPACES:
continue
logger.info("-----")
logger.info("----- Cleaning up k8s namespaces %s", namespace_name)
if ns.metadata.creation_timestamp > get_expire_timestamp():
logging.info(
"----- Skipped [resource is within expiry date]: %s",
namespace_name,
)
continue
if dry_run:
# Skip deletion for dry-runs
logging.info("----- Skipped [Dry Run]: %s", ns.metadata.name)
continue
rule: K8sResourceRule = _rule_match_k8s_namespace(
namespace_name, k8s_resource_rules
)
if not rule:
logging.info(
"----- Skipped [does not matching resource name templates]: %s",
namespace_name,
)
continue
# Cleaning up.
try:
rule.cleanup_ns_fn(
project,
network,
k8s_api_manager,
namespace_name,
gcp_api_manager,
gcp_service_account,
suffix=("alt" if SECONDARY.value else None),
)
except k8s.NotFound:
logging.warning("----- Skipped [not found]: %s", namespace_name)
except retryers.RetryError as err:
_CLEANUP_RESULT.add_error(
"Retries exhausted while waiting for the "
f"deletion of namespace {namespace_name}: "
f"{err}"
)
logging.exception(
"----- Skipped [cleanup timed out]: %s", namespace_name
)
except Exception as err: # noqa pylint: disable=broad-except
_CLEANUP_RESULT.add_error(
"Unexpected error while deleting "
f"namespace {namespace_name}: {err}"
)
logging.exception(
"----- Skipped [cleanup unexpected error]: %s", namespace_name
)
logger.info("-----")
def _rule_match_k8s_namespace(
namespace_name: str, k8s_resource_rules: List[K8sResourceRule]
) -> Optional[K8sResourceRule]:
for rule in k8s_resource_rules:
result = re.search(rule.expression, namespace_name)
if result is not None:
return rule
return None
def find_and_remove_leaked_k8s_resources(
dry_run, project, network, gcp_service_account, k8s_context
):
k8s_resource_rules: List[K8sResourceRule] = []
for prefix in CLIENT_PREFIXES.value:
k8s_resource_rules.append(
K8sResourceRule(f"{prefix}-client-(.*)", cleanup_client)
)
for prefix in SERVER_PREFIXES.value:
k8s_resource_rules.append(
K8sResourceRule(f"{prefix}-server-(.*)", cleanup_server)
)
# Delete leaked k8s namespaces; those usually mean there are leaked test
# clients/servers from the gke framework.
k8s_api_manager = k8s.KubernetesApiManager(k8s_context)
nss = k8s_api_manager.core.list_namespace()
delete_k8s_resources(
dry_run,
k8s_resource_rules,
project,
network,
k8s_api_manager,
gcp_service_account,
nss.items,
)
def find_and_remove_leaked_td_resources(dry_run, project, network):
cleanup_legacy: bool = MODE.value != "td_no_legacy"
td_resource_rules = [
# items in each tuple, in order:
# - regex to match
# - prefix of the resource (only used by gke resources)
# - function to check if the resource should be kept
# - function to delete the resource
]
if cleanup_legacy:
td_resource_rules += [
(
r"test-hc(.*)",
"",
is_marked_as_keep_gce,
cleanup_legacy_driver_resources,
),
(
r"test-template(.*)",
"",
is_marked_as_keep_gce,
cleanup_legacy_driver_resources,
),
]
for prefix in TD_RESOURCE_PREFIXES.value:
td_resource_rules.append(
(
f"{prefix}-health-check-(.*)",
prefix,
is_marked_as_keep_gke,
cleanup_td_for_gke,
),
)
# List resources older than KEEP_PERIOD. We only list health-checks and
# instance templates because these are leaves in the resource dependency
# tree.
#
# E.g. a forwarding-rule depends on the target-proxy, so a leaked
# forwarding-rule indicates there's a leaked target-proxy (because the
# target proxy cannot be deleted unless the forwarding rule is deleted). The
# leaked target-proxies are guaranteed to be a superset of the leaked
# forwarding-rules.
compute = gcp.compute.ComputeV1(gcp.api.GcpApiManager(), project)
leaked_health_checks = []
for item in compute.list_health_check()["items"]:
if (
dateutil.parser.isoparse(item["creationTimestamp"])
<= get_expire_timestamp()
):
leaked_health_checks.append(item)
delete_leaked_td_resources(
dry_run, td_resource_rules, project, network, leaked_health_checks
)
# Delete leaked instance templates; those usually mean there are leaked VMs
# from the gce framework. Also note that this is only needed for gce
# resources.
if cleanup_legacy:
leaked_instance_templates = exec_gcloud(
project, "compute", "instance-templates", "list"
)
delete_leaked_td_resources(
dry_run,
td_resource_rules,
project,
network,
leaked_instance_templates,
)
def main(argv):
# TODO(sergiitk): instead, base on absltest so that result.xml is available.
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
load_keep_config()
# Must be called before KubernetesApiManager or GcpApiManager init.
xds_flags.set_socket_default_timeout_from_flag()
project: str = xds_flags.PROJECT.value
network: str = xds_flags.NETWORK.value
gcp_service_account: str = xds_k8s_flags.GCP_SERVICE_ACCOUNT.value
dry_run: bool = DRY_RUN.value
k8s_context: str = xds_k8s_flags.KUBE_CONTEXT.value
if MODE.value == "td" or MODE.value == "td_no_legacy":
find_and_remove_leaked_td_resources(dry_run, project, network)
elif MODE.value == "k8s":
# 'unset' value is used in td-only mode to bypass the validation
# for the required flag.
assert k8s_context != "unset"
find_and_remove_leaked_k8s_resources(
dry_run, project, network, gcp_service_account, k8s_context
)
logger.info("##################### Done cleaning up #####################")
if _CLEANUP_RESULT.error_count > 0:
logger.error(
"Cleanup failed for %i resource(s). Errors: [\n%s\n].\n"
"Please inspect the log files for stack traces corresponding "
"to these errors.",
_CLEANUP_RESULT.error_count,
_CLEANUP_RESULT.format_messages(),
)
sys.exit(1)
if __name__ == "__main__":
app.run(main)

@ -1,8 +0,0 @@
{
"gce_framework": {
"suffix": []
},
"gke_framework": {
"suffix": []
}
}

@ -1,95 +0,0 @@
#!/usr/bin/env bash
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eo pipefail
SCRIPT_DIR="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
readonly SCRIPT_DIR
readonly XDS_K8S_DRIVER_DIR="${SCRIPT_DIR}/.."
cd "${XDS_K8S_DRIVER_DIR}"
NO_SECURE="yes"
DATE_TO=$(date -Iseconds)
while [[ $# -gt 0 ]]; do
case $1 in
--secure) NO_SECURE=""; shift ;;
--date_to=*) DATE_TO="${1#*=}T00:00:00Z"; shift ;;
*) echo "Unknown argument $1"; exit 1 ;;
esac
done
jq_selector=$(cat <<- 'EOM'
.items[].metadata |
select(
(.name | test("-(client|server)-")) and
(.creationTimestamp < $date_to)
) | .name
EOM
)
mapfile -t namespaces < <(\
kubectl get namespaces --sort-by='{.metadata.creationTimestamp}'\
--selector='owner=xds-k8s-interop-test'\
-o json\
| jq --arg date_to "${DATE_TO}" -r "${jq_selector}"
)
if [[ -z "${namespaces[*]}" ]]; then
echo "All clean."
exit 0
fi
echo "Found namespaces:"
namespaces_joined=$(IFS=,; printf '%s' "${namespaces[*]}")
kubectl get namespaces --sort-by='{.metadata.creationTimestamp}' \
--selector="name in (${namespaces_joined})"
# Suffixes
mapfile -t suffixes < <(\
printf '%s\n' "${namespaces[@]}" | sed -E 's/^.+-(server|client)-//'
)
echo
echo "Found suffixes: ${suffixes[*]}"
echo "Count: ${#namespaces[@]}"
echo "Run plan:"
for suffix in "${suffixes[@]}"; do
echo ./bin/cleanup.sh ${NO_SECURE:+"--nosecure"} "--resource_suffix=${suffix}"
done
read -r -n 1 -p "Continue? (y/N) " answer
if [[ "$answer" != "${answer#[Yy]}" ]] ;then
echo
echo "Starting the cleanup."
else
echo
echo "Exit"
exit 0
fi
failed=0
for suffix in "${suffixes[@]}"; do
echo "-------------------- Cleaning suffix ${suffix} --------------------"
set -x
./bin/cleanup.sh ${NO_SECURE:+"--nosecure"} "--resource_suffix=${suffix}" || (( ++failed ))
set +x
echo "-------------------- Finished cleaning ${suffix} --------------------"
done
echo "Failed runs: ${failed}"
if (( failed > 0 )); then
exit 1
fi

@ -1,29 +0,0 @@
#!/usr/bin/env bash
# Copyright 2021 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Expects $XDS_K8S_DRIVER_DIR to be set by the file sourcing this script.
readonly XDS_K8S_DRIVER_VENV_DIR="${XDS_K8S_DRIVER_VENV_DIR:-$XDS_K8S_DRIVER_DIR/venv}"
if [[ -z "${VIRTUAL_ENV}" ]]; then
if [[ -d "${XDS_K8S_DRIVER_VENV_DIR}" ]]; then
# Intentional: No need to check python venv activate script.
# shellcheck source=/dev/null
source "${XDS_K8S_DRIVER_VENV_DIR}/bin/activate"
else
echo "Missing python virtual environment directory: ${XDS_K8S_DRIVER_VENV_DIR}" >&2
echo "Follow README.md installation steps first." >&2
exit 1
fi
fi

@ -1,28 +0,0 @@
#!/usr/bin/env bash
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -exo pipefail
VENV_NAME="venv-$(mktemp -d)"
readonly VENV_NAME
python3 -m virtualenv "${VENV_NAME}"
"${VENV_NAME}"/bin/pip install -r requirements.txt
"${VENV_NAME}"/bin/pip freeze --require-virtualenv --local -r requirements.txt \
> requirements.lock
rm -rf "${VENV_NAME}"

@ -1,60 +0,0 @@
#!/usr/bin/env bash
# Copyright 2021 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eo pipefail
display_usage() {
cat <<EOF >/dev/stderr
A helper to run the isort import sorter.
USAGE: $0 [--diff]
--diff: Do not apply changes, only show the diff
ENVIRONMENT:
XDS_K8S_DRIVER_VENV_DIR: the path to python virtual environment directory
Default: $XDS_K8S_DRIVER_DIR/venv
EXAMPLES:
$0
$0 --diff
EOF
exit 1
}
if [[ "$1" == "-h" || "$1" == "--help" ]]; then
display_usage
fi
SCRIPT_DIR="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
readonly SCRIPT_DIR
readonly XDS_K8S_DRIVER_DIR="${SCRIPT_DIR}/.."
cd "${XDS_K8S_DRIVER_DIR}"
# Relative paths not yet supported by shellcheck.
# shellcheck source=/dev/null
source "${XDS_K8S_DRIVER_DIR}/bin/ensure_venv.sh"
if [[ "$1" == "--diff" ]]; then
readonly MODE="--diff"
else
readonly MODE="--overwrite-in-place"
fi
# typing is the only module allowed to put imports on the same line:
# https://google.github.io/styleguide/pyguide.html#313-imports-formatting
exec python -m isort "${MODE}" \
--settings-path=../../../black.toml \
framework bin tests

@ -1,13 +0,0 @@
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,191 +0,0 @@
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common functionality for bin/ python helpers."""
import atexit
import signal
import sys
from typing import Optional
from absl import logging
from framework import xds_flags
from framework import xds_k8s_flags
from framework.infrastructure import gcp
from framework.infrastructure import k8s
from framework.test_app import client_app
from framework.test_app import server_app
from framework.test_app.runners.k8s import gamma_server_runner
from framework.test_app.runners.k8s import k8s_xds_client_runner
from framework.test_app.runners.k8s import k8s_xds_server_runner
logger = logging.get_absl_logger()
# Type aliases
KubernetesClientRunner = k8s_xds_client_runner.KubernetesClientRunner
KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner
GammaServerRunner = gamma_server_runner.GammaServerRunner
_XdsTestServer = server_app.XdsTestServer
_XdsTestClient = client_app.XdsTestClient
def make_client_namespace(
k8s_api_manager: k8s.KubernetesApiManager,
    namespace_name: Optional[str] = None,
) -> k8s.KubernetesNamespace:
if not namespace_name:
namespace_name: str = KubernetesClientRunner.make_namespace_name(
xds_flags.RESOURCE_PREFIX.value, xds_flags.RESOURCE_SUFFIX.value
)
return k8s.KubernetesNamespace(k8s_api_manager, namespace_name)
def make_client_runner(
namespace: k8s.KubernetesNamespace,
gcp_api_manager: gcp.api.GcpApiManager,
*,
port_forwarding: bool = False,
reuse_namespace: bool = True,
enable_workload_identity: bool = True,
mode: str = "default",
) -> KubernetesClientRunner:
# KubernetesClientRunner arguments.
runner_kwargs = dict(
deployment_name=xds_flags.CLIENT_NAME.value,
image_name=xds_k8s_flags.CLIENT_IMAGE.value,
td_bootstrap_image=xds_k8s_flags.TD_BOOTSTRAP_IMAGE.value,
gcp_project=xds_flags.PROJECT.value,
gcp_api_manager=gcp_api_manager,
gcp_service_account=xds_k8s_flags.GCP_SERVICE_ACCOUNT.value,
xds_server_uri=xds_flags.XDS_SERVER_URI.value,
network=xds_flags.NETWORK.value,
stats_port=xds_flags.CLIENT_PORT.value,
reuse_namespace=reuse_namespace,
debug_use_port_forwarding=port_forwarding,
enable_workload_identity=enable_workload_identity,
)
if mode == "secure":
runner_kwargs.update(
deployment_template="client-secure.deployment.yaml"
)
return KubernetesClientRunner(namespace, **runner_kwargs)
def make_server_namespace(
k8s_api_manager: k8s.KubernetesApiManager,
    server_runner: type[KubernetesServerRunner] = KubernetesServerRunner,
) -> k8s.KubernetesNamespace:
namespace_name: str = server_runner.make_namespace_name(
xds_flags.RESOURCE_PREFIX.value, xds_flags.RESOURCE_SUFFIX.value
)
return k8s.KubernetesNamespace(k8s_api_manager, namespace_name)
def make_server_runner(
namespace: k8s.KubernetesNamespace,
gcp_api_manager: gcp.api.GcpApiManager,
*,
port_forwarding: bool = False,
reuse_namespace: bool = True,
reuse_service: bool = False,
enable_workload_identity: bool = True,
mode: str = "default",
) -> KubernetesServerRunner:
# KubernetesServerRunner arguments.
runner_kwargs = dict(
deployment_name=xds_flags.SERVER_NAME.value,
image_name=xds_k8s_flags.SERVER_IMAGE.value,
td_bootstrap_image=xds_k8s_flags.TD_BOOTSTRAP_IMAGE.value,
xds_server_uri=xds_flags.XDS_SERVER_URI.value,
gcp_project=xds_flags.PROJECT.value,
gcp_api_manager=gcp_api_manager,
gcp_service_account=xds_k8s_flags.GCP_SERVICE_ACCOUNT.value,
network=xds_flags.NETWORK.value,
reuse_namespace=reuse_namespace,
reuse_service=reuse_service,
debug_use_port_forwarding=port_forwarding,
enable_workload_identity=enable_workload_identity,
)
server_runner = KubernetesServerRunner
if mode == "secure":
runner_kwargs["deployment_template"] = "server-secure.deployment.yaml"
elif mode == "gamma":
server_runner = GammaServerRunner
return server_runner(namespace, **runner_kwargs)
def _ensure_atexit(signum, frame):
"""Needed to handle signals or atexit handler won't be called."""
del frame
# Pylint is wrong about "Module 'signal' has no 'Signals' member":
# https://docs.python.org/3/library/signal.html#signal.Signals
sig = signal.Signals(signum) # pylint: disable=no-member
logger.warning("Caught %r, initiating graceful shutdown...\n", sig)
sys.exit(1)
def _graceful_exit(
server_runner: KubernetesServerRunner, client_runner: KubernetesClientRunner
):
"""Stop port forwarding processes."""
client_runner.stop_pod_dependencies()
server_runner.stop_pod_dependencies()
def register_graceful_exit(
server_runner: KubernetesServerRunner, client_runner: KubernetesClientRunner
):
atexit.register(_graceful_exit, server_runner, client_runner)
for signum in (signal.SIGTERM, signal.SIGHUP, signal.SIGINT):
signal.signal(signum, _ensure_atexit)
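# A minimal sketch of the signal-to-atexit pattern used above, assuming a
# hypothetical `cleanup` callback. atexit handlers only run on normal
# interpreter exit, so each fatal signal is converted into sys.exit(),
# which triggers them:
#
#   atexit.register(cleanup)
#   for signum in (signal.SIGTERM, signal.SIGHUP, signal.SIGINT):
#       signal.signal(signum, lambda sig, frame: sys.exit(1))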
def get_client_pod(
client_runner: KubernetesClientRunner, deployment_name: str
) -> k8s.V1Pod:
client_deployment: k8s.V1Deployment
client_deployment = client_runner.k8s_namespace.get_deployment(
deployment_name
)
client_pod_name: str = client_runner._wait_deployment_pod_count(
client_deployment
)[0]
return client_runner._wait_pod_started(client_pod_name)
def get_server_pod(
server_runner: KubernetesServerRunner, deployment_name: str
) -> k8s.V1Pod:
server_deployment: k8s.V1Deployment
server_deployment = server_runner.k8s_namespace.get_deployment(
deployment_name
)
server_pod_name: str = server_runner._wait_deployment_pod_count(
server_deployment
)[0]
return server_runner._wait_pod_started(server_pod_name)
def get_test_server_for_pod(
server_runner: KubernetesServerRunner, server_pod: k8s.V1Pod, **kwargs
) -> _XdsTestServer:
return server_runner._xds_test_server_for_pod(server_pod, **kwargs)
def get_test_client_for_pod(
client_runner: KubernetesClientRunner, client_pod: k8s.V1Pod, **kwargs
) -> _XdsTestClient:
return client_runner._xds_test_client_for_pod(client_pod, **kwargs)

@ -1,271 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Channelz debugging tool for xDS test client/server.
This is intended as a debugging / local development helper and not executed
as a part of interop test suites.
Typical usage examples:
# Show channel and server socket pair
python -m bin.run_channelz --flagfile=config/local-dev.cfg
# Evaluate setup for different security configurations
python -m bin.run_channelz --flagfile=config/local-dev.cfg --security=tls
python -m bin.run_channelz --flagfile=config/local-dev.cfg --security=mtls_error
# More information and usage options
python -m bin.run_channelz --helpfull
"""
import hashlib
from absl import app
from absl import flags
from absl import logging
from bin.lib import common
from framework import xds_flags
from framework import xds_k8s_flags
from framework.infrastructure import gcp
from framework.infrastructure import k8s
from framework.rpc import grpc_channelz
from framework.test_app import client_app
from framework.test_app import server_app
# Flags
_SECURITY = flags.DEFINE_enum(
"security",
default=None,
enum_values=[
"mtls",
"tls",
"plaintext",
"mtls_error",
"server_authz_error",
],
help="Show info for a security setup",
)
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)
# Running outside of a test suite, so require explicit resource_suffix.
flags.mark_flag_as_required(xds_flags.RESOURCE_SUFFIX.name)
flags.register_validator(
xds_flags.SERVER_XDS_PORT.name,
lambda val: val > 0,
message=(
"Run outside of a test suite, must provide"
" the exact port value (must be greater than 0)."
),
)
logger = logging.get_absl_logger()
# Type aliases
_Channel = grpc_channelz.Channel
_Socket = grpc_channelz.Socket
_ChannelState = grpc_channelz.ChannelState
_XdsTestServer = server_app.XdsTestServer
_XdsTestClient = client_app.XdsTestClient
def debug_cert(cert):
if not cert:
return "<missing>"
sha1 = hashlib.sha1(cert)
return f"sha1={sha1.hexdigest()}, len={len(cert)}"
def debug_sock_tls(tls):
return (
f"local: {debug_cert(tls.local_certificate)}\n"
f"remote: {debug_cert(tls.remote_certificate)}"
)
def get_deployment_pods(k8s_ns, deployment_name):
deployment = k8s_ns.get_deployment(deployment_name)
return k8s_ns.list_deployment_pods(deployment)
def debug_security_setup_negative(test_client):
"""Debug negative cases: mTLS Error, Server AuthZ error
    1) mTLS Error: Server expects a client mTLS cert,
    but the client is configured for TLS only.
    2) AuthZ error: Client does not authorize the server because of a
    mismatched SAN name.
"""
# Client side.
client_correct_setup = True
channel: _Channel = test_client.wait_for_server_channel_state(
state=_ChannelState.TRANSIENT_FAILURE
)
try:
subchannel, *subchannels = list(
test_client.channelz.list_channel_subchannels(channel)
)
except ValueError:
print(
"Client setup fail: subchannel not found. "
"Common causes: test client didn't connect to TD; "
"test client exhausted retries, and closed all subchannels."
)
return
# Client must have exactly one subchannel.
logger.debug("Found subchannel, %s", subchannel)
if subchannels:
client_correct_setup = False
print(f"Unexpected subchannels {subchannels}")
subchannel_state: _ChannelState = subchannel.data.state.state
if subchannel_state is not _ChannelState.TRANSIENT_FAILURE:
client_correct_setup = False
print(
"Subchannel expected to be in "
"TRANSIENT_FAILURE, same as its channel"
)
# Client subchannel must have no sockets.
sockets = list(test_client.channelz.list_subchannels_sockets(subchannel))
if sockets:
client_correct_setup = False
print(f"Unexpected subchannel sockets {sockets}")
# Results.
if client_correct_setup:
print(
"Client setup pass: the channel "
"to the server has exactly one subchannel "
"in TRANSIENT_FAILURE, and no sockets"
)
def debug_security_setup_positive(test_client, test_server):
"""Debug positive cases: mTLS, TLS, Plaintext."""
test_client.wait_for_server_channel_ready()
client_sock: _Socket = test_client.get_active_server_channel_socket()
server_sock: _Socket = test_server.get_server_socket_matching_client(
client_sock
)
server_tls = server_sock.security.tls
client_tls = client_sock.security.tls
print(f"\nServer certs:\n{debug_sock_tls(server_tls)}")
print(f"\nClient certs:\n{debug_sock_tls(client_tls)}")
print()
if server_tls.local_certificate:
eq = server_tls.local_certificate == client_tls.remote_certificate
print(f"(TLS) Server local matches client remote: {eq}")
else:
print("(TLS) Not detected")
if server_tls.remote_certificate:
eq = server_tls.remote_certificate == client_tls.local_certificate
print(f"(mTLS) Server remote matches client local: {eq}")
else:
print("(mTLS) Not detected")
def debug_basic_setup(test_client, test_server):
"""Show channel and server socket pair"""
test_client.wait_for_server_channel_ready()
client_sock: _Socket = test_client.get_active_server_channel_socket()
server_sock: _Socket = test_server.get_server_socket_matching_client(
client_sock
)
logger.debug("Client socket: %s\n", client_sock)
logger.debug("Matching server socket: %s\n", server_sock)
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
# Must be called before KubernetesApiManager or GcpApiManager init.
xds_flags.set_socket_default_timeout_from_flag()
# Flags.
should_port_forward: bool = xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
enable_workload_identity: bool = (
xds_k8s_flags.ENABLE_WORKLOAD_IDENTITY.value
)
is_secure: bool = bool(_SECURITY.value)
# Setup.
gcp_api_manager = gcp.api.GcpApiManager()
k8s_api_manager = k8s.KubernetesApiManager(xds_k8s_flags.KUBE_CONTEXT.value)
# Server.
server_namespace = common.make_server_namespace(k8s_api_manager)
server_runner = common.make_server_runner(
server_namespace,
gcp_api_manager,
port_forwarding=should_port_forward,
enable_workload_identity=enable_workload_identity,
mode="secure",
)
# Find server pod.
server_pod: k8s.V1Pod = common.get_server_pod(
server_runner, xds_flags.SERVER_NAME.value
)
# Client
client_namespace = common.make_client_namespace(k8s_api_manager)
client_runner = common.make_client_runner(
client_namespace,
gcp_api_manager,
port_forwarding=should_port_forward,
enable_workload_identity=enable_workload_identity,
mode="secure",
)
# Find client pod.
client_pod: k8s.V1Pod = common.get_client_pod(
client_runner, xds_flags.CLIENT_NAME.value
)
# Ensure port forwarding stopped.
common.register_graceful_exit(server_runner, client_runner)
# Create server app for the server pod.
test_server: _XdsTestServer = common.get_test_server_for_pod(
server_runner,
server_pod,
test_port=xds_flags.SERVER_PORT.value,
secure_mode=is_secure,
)
test_server.set_xds_address(
xds_flags.SERVER_XDS_HOST.value, xds_flags.SERVER_XDS_PORT.value
)
# Create client app for the client pod.
test_client: _XdsTestClient = common.get_test_client_for_pod(
client_runner, client_pod, server_target=test_server.xds_uri
)
with test_client, test_server:
if _SECURITY.value in ("mtls", "tls", "plaintext"):
debug_security_setup_positive(test_client, test_server)
elif _SECURITY.value in ("mtls_error", "server_authz_error"):
debug_security_setup_negative(test_client)
else:
debug_basic_setup(test_client, test_server)
logger.info("SUCCESS!")
if __name__ == "__main__":
app.run(main)

@ -1,169 +0,0 @@
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl import app
from absl import flags
from absl import logging
from bin.lib import common
from framework import xds_flags
from framework import xds_k8s_flags
from framework.helpers import grpc as helpers_grpc
import framework.helpers.highlighter
from framework.infrastructure import gcp
from framework.infrastructure import k8s
from framework.rpc import grpc_channelz
from framework.rpc import grpc_testing
from framework.test_app import client_app
from framework.test_app import server_app
# Flags
_MODE = flags.DEFINE_enum(
"mode",
default="default",
enum_values=["default", "secure", "gamma"],
help="Select a deployment of the client/server",
)
_NUM_RPCS = flags.DEFINE_integer(
"num_rpcs",
default=100,
lower_bound=1,
upper_bound=10_000,
help="The number of RPCs to check.",
)
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)
# Running outside of a test suite, so require explicit resource_suffix.
flags.mark_flag_as_required(xds_flags.RESOURCE_SUFFIX.name)
flags.register_validator(
xds_flags.SERVER_XDS_PORT.name,
lambda val: val > 0,
message=(
"Run outside of a test suite, must provide"
" the exact port value (must be greater than 0)."
),
)
logger = logging.get_absl_logger()
# Type aliases
_Channel = grpc_channelz.Channel
_Socket = grpc_channelz.Socket
_ChannelState = grpc_channelz.ChannelState
_XdsTestServer = server_app.XdsTestServer
_XdsTestClient = client_app.XdsTestClient
LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse
def get_client_rpc_stats(
test_client: _XdsTestClient, num_rpcs: int
) -> LoadBalancerStatsResponse:
lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
hl = framework.helpers.highlighter.HighlighterYaml()
logger.info(
"[%s] Received LoadBalancerStatsResponse:\n%s",
test_client.hostname,
hl.highlight(helpers_grpc.lb_stats_pretty(lb_stats)),
)
return lb_stats
def run_ping_pong(test_client: _XdsTestClient, num_rpcs: int):
test_client.wait_for_active_xds_channel()
test_client.wait_for_server_channel_ready()
lb_stats = get_client_rpc_stats(test_client, num_rpcs)
for backend, rpcs_count in lb_stats.rpcs_by_peer.items():
if int(rpcs_count) < 1:
raise AssertionError(
f"Backend {backend} did not receive a single RPC"
)
    failed = int(lb_stats.num_failures)
    if failed > 0:
raise AssertionError(
f"Expected all RPCs to succeed: {failed} of {num_rpcs} failed"
)
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
# Must be called before KubernetesApiManager or GcpApiManager init.
xds_flags.set_socket_default_timeout_from_flag()
# Flags.
should_port_forward: bool = xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
enable_workload_identity: bool = (
xds_k8s_flags.ENABLE_WORKLOAD_IDENTITY.value
)
# Setup.
gcp_api_manager = gcp.api.GcpApiManager()
k8s_api_manager = k8s.KubernetesApiManager(xds_k8s_flags.KUBE_CONTEXT.value)
# Server.
server_namespace = common.make_server_namespace(k8s_api_manager)
server_runner = common.make_server_runner(
server_namespace,
gcp_api_manager,
port_forwarding=should_port_forward,
enable_workload_identity=enable_workload_identity,
mode=_MODE.value,
)
# Find server pod.
server_pod: k8s.V1Pod = common.get_server_pod(
server_runner, xds_flags.SERVER_NAME.value
)
# Client
client_namespace = common.make_client_namespace(k8s_api_manager)
client_runner = common.make_client_runner(
client_namespace,
gcp_api_manager,
port_forwarding=should_port_forward,
enable_workload_identity=enable_workload_identity,
mode=_MODE.value,
)
# Find client pod.
client_pod: k8s.V1Pod = common.get_client_pod(
client_runner, xds_flags.CLIENT_NAME.value
)
# Ensure port forwarding stopped.
common.register_graceful_exit(server_runner, client_runner)
# Create server app for the server pod.
test_server: _XdsTestServer = common.get_test_server_for_pod(
server_runner,
server_pod,
test_port=xds_flags.SERVER_PORT.value,
secure_mode=_MODE.value == "secure",
)
test_server.set_xds_address(
xds_flags.SERVER_XDS_HOST.value, xds_flags.SERVER_XDS_PORT.value
)
# Create client app for the client pod.
test_client: _XdsTestClient = common.get_test_client_for_pod(
client_runner, client_pod, server_target=test_server.xds_uri
)
with test_client, test_server:
run_ping_pong(test_client, _NUM_RPCS.value)
logger.info("SUCCESS!")
if __name__ == "__main__":
app.run(main)

@ -1,310 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configure Traffic Director for different GRPC Proxyless.
This is intended as a debugging / local development helper and not executed
as a part of interop test suites.
Typical usage examples:
# Regular proxyless setup
python -m bin.run_td_setup --flagfile=config/local-dev.cfg
# Additional commands: cleanup, backend management, etc.
python -m bin.run_td_setup --flagfile=config/local-dev.cfg --cmd=cleanup
# PSM security setup options: mtls, tls, etc.
python -m bin.run_td_setup --flagfile=config/local-dev.cfg --security=mtls
# More information and usage options
python -m bin.run_td_setup --helpfull
"""
import logging
from absl import app
from absl import flags
from framework import xds_flags
from framework import xds_k8s_flags
from framework.helpers import rand
from framework.infrastructure import gcp
from framework.infrastructure import k8s
from framework.infrastructure import traffic_director
from framework.test_app.runners.k8s import k8s_xds_server_runner
logger = logging.getLogger(__name__)
# Flags
_CMD = flags.DEFINE_enum(
"cmd",
default="create",
enum_values=[
"cycle",
"create",
"cleanup",
"backends-add",
"backends-cleanup",
"unused-xds-port",
],
help="Command",
)
_SECURITY = flags.DEFINE_enum(
"security",
default=None,
enum_values=[
"mtls",
"tls",
"plaintext",
"mtls_error",
"server_authz_error",
],
help="Configure TD with security",
)
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)
# Running outside of a test suite, so require explicit resource_suffix.
flags.mark_flag_as_required(xds_flags.RESOURCE_SUFFIX.name)
@flags.multi_flags_validator(
(xds_flags.SERVER_XDS_PORT.name, _CMD.name),
message=(
"Run outside of a test suite, must provide"
" the exact port value (must be greater than 0)."
),
)
def _check_server_xds_port_flag(flags_dict):
if flags_dict[_CMD.name] not in ("create", "cycle"):
return True
return flags_dict[xds_flags.SERVER_XDS_PORT.name] > 0
# Type aliases
_KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner
def main(
argv,
): # pylint: disable=too-many-locals,too-many-branches,too-many-statements
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
# Must be called before KubernetesApiManager or GcpApiManager init.
xds_flags.set_socket_default_timeout_from_flag()
command = _CMD.value
security_mode = _SECURITY.value
project: str = xds_flags.PROJECT.value
network: str = xds_flags.NETWORK.value
# Resource names.
resource_prefix: str = xds_flags.RESOURCE_PREFIX.value
resource_suffix: str = xds_flags.RESOURCE_SUFFIX.value
# Test server
server_name = xds_flags.SERVER_NAME.value
server_port = xds_flags.SERVER_PORT.value
server_maintenance_port = xds_flags.SERVER_MAINTENANCE_PORT.value
server_xds_host = xds_flags.SERVER_XDS_HOST.value
server_xds_port = xds_flags.SERVER_XDS_PORT.value
server_namespace = _KubernetesServerRunner.make_namespace_name(
resource_prefix, resource_suffix
)
gcp_api_manager = gcp.api.GcpApiManager()
if security_mode is None:
td = traffic_director.TrafficDirectorManager(
gcp_api_manager,
project=project,
network=network,
resource_prefix=resource_prefix,
resource_suffix=resource_suffix,
)
else:
td = traffic_director.TrafficDirectorSecureManager(
gcp_api_manager,
project=project,
network=network,
resource_prefix=resource_prefix,
resource_suffix=resource_suffix,
)
if server_maintenance_port is None:
server_maintenance_port = (
_KubernetesServerRunner.DEFAULT_SECURE_MODE_MAINTENANCE_PORT
)
try:
if command in ("create", "cycle"):
logger.info("Create mode")
if security_mode is None:
logger.info("No security")
td.setup_for_grpc(
server_xds_host,
server_xds_port,
health_check_port=server_maintenance_port,
)
elif security_mode == "mtls":
logger.info("Setting up mtls")
td.setup_for_grpc(
server_xds_host,
server_xds_port,
health_check_port=server_maintenance_port,
)
td.setup_server_security(
server_namespace=server_namespace,
server_name=server_name,
server_port=server_port,
tls=True,
mtls=True,
)
td.setup_client_security(
server_namespace=server_namespace,
server_name=server_name,
tls=True,
mtls=True,
)
elif security_mode == "tls":
logger.info("Setting up tls")
td.setup_for_grpc(
server_xds_host,
server_xds_port,
health_check_port=server_maintenance_port,
)
td.setup_server_security(
server_namespace=server_namespace,
server_name=server_name,
server_port=server_port,
tls=True,
mtls=False,
)
td.setup_client_security(
server_namespace=server_namespace,
server_name=server_name,
tls=True,
mtls=False,
)
elif security_mode == "plaintext":
logger.info("Setting up plaintext")
td.setup_for_grpc(
server_xds_host,
server_xds_port,
health_check_port=server_maintenance_port,
)
td.setup_server_security(
server_namespace=server_namespace,
server_name=server_name,
server_port=server_port,
tls=False,
mtls=False,
)
td.setup_client_security(
server_namespace=server_namespace,
server_name=server_name,
tls=False,
mtls=False,
)
elif security_mode == "mtls_error":
# Error case: server expects client mTLS cert,
# but client configured only for TLS
logger.info("Setting up mtls_error")
td.setup_for_grpc(
server_xds_host,
server_xds_port,
health_check_port=server_maintenance_port,
)
td.setup_server_security(
server_namespace=server_namespace,
server_name=server_name,
server_port=server_port,
tls=True,
mtls=True,
)
td.setup_client_security(
server_namespace=server_namespace,
server_name=server_name,
tls=True,
mtls=False,
)
elif security_mode == "server_authz_error":
# Error case: client does not authorize server
# because of mismatched SAN name.
logger.info("Setting up mtls_error")
td.setup_for_grpc(
server_xds_host,
server_xds_port,
health_check_port=server_maintenance_port,
)
# Regular TLS setup, but with client policy configured using
            # intentionally incorrect server_namespace.
td.setup_server_security(
server_namespace=server_namespace,
server_name=server_name,
server_port=server_port,
tls=True,
mtls=False,
)
td.setup_client_security(
server_namespace=(
f"incorrect-namespace-{rand.rand_string()}"
),
server_name=server_name,
tls=True,
mtls=False,
)
logger.info("Works!")
except Exception: # noqa pylint: disable=broad-except
logger.exception("Got error during creation")
if command in ("cleanup", "cycle"):
logger.info("Cleaning up")
td.cleanup(force=True)
if command == "backends-add":
logger.info("Adding backends")
k8s_api_manager = k8s.KubernetesApiManager(
xds_k8s_flags.KUBE_CONTEXT.value
)
k8s_namespace = k8s.KubernetesNamespace(
k8s_api_manager, server_namespace
)
neg_name, neg_zones = k8s_namespace.parse_service_neg_status(
server_name, server_port
)
td.load_backend_service()
td.backend_service_add_neg_backends(neg_name, neg_zones)
td.wait_for_backends_healthy_status()
elif command == "backends-cleanup":
td.load_backend_service()
td.backend_service_remove_all_backends()
elif command == "unused-xds-port":
try:
unused_xds_port = td.find_unused_forwarding_rule_port()
logger.info(
"Found unused forwarding rule port: %s", unused_xds_port
)
except Exception: # noqa pylint: disable=broad-except
logger.exception("Couldn't find unused forwarding rule port")
if __name__ == "__main__":
app.run(main)

@ -1,162 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the xDS test client.
Gamma example:
./run.sh bin/run_test_client.py --server_xds_host=psm-grpc-server \
--server_xds_port=80 \
--config_mesh=gketd-psm-grpc-server
"""
import logging
import signal
from absl import app
from absl import flags
from bin.lib import common
from framework import xds_flags
from framework import xds_k8s_flags
from framework.infrastructure import gcp
from framework.infrastructure import k8s
logger = logging.getLogger(__name__)
# Flags
_CMD = flags.DEFINE_enum(
"cmd", default="run", enum_values=["run", "cleanup"], help="Command"
)
_MODE = flags.DEFINE_enum(
"mode",
default="default",
enum_values=[
"default",
"secure",
        # Uncomment if gamma-specific changes are added to the client.
# "gamma",
],
help="Select client mode",
)
_QPS = flags.DEFINE_integer("qps", default=25, help="Queries per second")
_PRINT_RESPONSE = flags.DEFINE_bool(
"print_response", default=False, help="Client prints responses"
)
_FOLLOW = flags.DEFINE_bool(
"follow",
default=False,
help=(
"Follow pod logs. Requires --collect_app_logs or"
" --debug_use_port_forwarding"
),
)
_CONFIG_MESH = flags.DEFINE_string(
"config_mesh",
default=None,
help="Optional. Supplied to bootstrap generator to indicate AppNet mesh.",
)
_REUSE_NAMESPACE = flags.DEFINE_bool(
"reuse_namespace", default=True, help="Use existing namespace if exists"
)
_CLEANUP_NAMESPACE = flags.DEFINE_bool(
"cleanup_namespace",
default=False,
help="Delete namespace during resource cleanup",
)
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)
# Running outside of a test suite, so require explicit resource_suffix.
flags.mark_flag_as_required(xds_flags.RESOURCE_SUFFIX.name)
@flags.multi_flags_validator(
(xds_flags.SERVER_XDS_PORT.name, _CMD.name),
message=(
"Run outside of a test suite, must provide"
" the exact port value (must be greater than 0)."
),
)
def _check_server_xds_port_flag(flags_dict):
if flags_dict[_CMD.name] == "cleanup":
return True
return flags_dict[xds_flags.SERVER_XDS_PORT.name] > 0
def _make_sigint_handler(client_runner: common.KubernetesClientRunner):
def sigint_handler(sig, frame):
del sig, frame
print("Caught Ctrl+C. Shutting down the logs")
client_runner.stop_pod_dependencies(log_drain_sec=3)
return sigint_handler
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
# Must be called before KubernetesApiManager or GcpApiManager init.
xds_flags.set_socket_default_timeout_from_flag()
# Log following and port forwarding.
should_follow_logs = _FOLLOW.value and xds_flags.COLLECT_APP_LOGS.value
should_port_forward = (
should_follow_logs and xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
)
enable_workload_identity: bool = (
xds_k8s_flags.ENABLE_WORKLOAD_IDENTITY.value
)
# Setup.
gcp_api_manager = gcp.api.GcpApiManager()
k8s_api_manager = k8s.KubernetesApiManager(xds_k8s_flags.KUBE_CONTEXT.value)
client_namespace = common.make_client_namespace(k8s_api_manager)
client_runner = common.make_client_runner(
client_namespace,
gcp_api_manager,
reuse_namespace=_REUSE_NAMESPACE.value,
mode=_MODE.value,
port_forwarding=should_port_forward,
enable_workload_identity=enable_workload_identity,
)
# Server target
server_target = f"xds:///{xds_flags.SERVER_XDS_HOST.value}"
if xds_flags.SERVER_XDS_PORT.value != 80:
server_target = f"{server_target}:{xds_flags.SERVER_XDS_PORT.value}"
if _CMD.value == "run":
logger.info("Run client, mode=%s", _MODE.value)
client_runner.run(
server_target=server_target,
qps=_QPS.value,
print_response=_PRINT_RESPONSE.value,
secure_mode=_MODE.value == "secure",
config_mesh=_CONFIG_MESH.value,
log_to_stdout=_FOLLOW.value,
)
if should_follow_logs:
print("Following pod logs. Press Ctrl+C top stop")
signal.signal(signal.SIGINT, _make_sigint_handler(client_runner))
signal.pause()
elif _CMD.value == "cleanup":
logger.info("Cleanup client")
client_runner.cleanup(
force=True, force_namespace=_CLEANUP_NAMESPACE.value
)
if __name__ == "__main__":
app.run(main)

@ -1,123 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the xDS test server.
Gamma example:
./run.sh bin/run_test_server.py --mode=gamma
"""
import logging
import signal
from absl import app
from absl import flags
from bin.lib import common
from framework import xds_flags
from framework import xds_k8s_flags
from framework.infrastructure import gcp
from framework.infrastructure import k8s
logger = logging.getLogger(__name__)
# Flags
_CMD = flags.DEFINE_enum(
"cmd", default="run", enum_values=["run", "cleanup"], help="Command"
)
_MODE = flags.DEFINE_enum(
"mode",
default="default",
enum_values=["default", "secure", "gamma"],
help="Select server mode",
)
_REUSE_NAMESPACE = flags.DEFINE_bool(
"reuse_namespace", default=True, help="Use existing namespace if exists"
)
_REUSE_SERVICE = flags.DEFINE_bool(
"reuse_service", default=False, help="Use existing service if exists"
)
_FOLLOW = flags.DEFINE_bool(
"follow", default=False, help="Follow pod logs. Requires --collect_app_logs"
)
_CLEANUP_NAMESPACE = flags.DEFINE_bool(
"cleanup_namespace",
default=False,
help="Delete namespace during resource cleanup",
)
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)
# Running outside of a test suite, so require explicit resource_suffix.
flags.mark_flag_as_required("resource_suffix")
def _make_sigint_handler(server_runner: common.KubernetesServerRunner):
def sigint_handler(sig, frame):
del sig, frame
print("Caught Ctrl+C. Shutting down the logs")
server_runner.stop_pod_dependencies(log_drain_sec=3)
return sigint_handler
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
# Must be called before KubernetesApiManager or GcpApiManager init.
xds_flags.set_socket_default_timeout_from_flag()
should_follow_logs = _FOLLOW.value and xds_flags.COLLECT_APP_LOGS.value
should_port_forward = (
should_follow_logs and xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
)
enable_workload_identity: bool = (
xds_k8s_flags.ENABLE_WORKLOAD_IDENTITY.value
)
# Setup.
gcp_api_manager = gcp.api.GcpApiManager()
k8s_api_manager = k8s.KubernetesApiManager(xds_k8s_flags.KUBE_CONTEXT.value)
server_namespace = common.make_server_namespace(k8s_api_manager)
server_runner = common.make_server_runner(
server_namespace,
gcp_api_manager,
reuse_namespace=_REUSE_NAMESPACE.value,
reuse_service=_REUSE_SERVICE.value,
mode=_MODE.value,
port_forwarding=should_port_forward,
enable_workload_identity=enable_workload_identity,
)
if _CMD.value == "run":
logger.info("Run server, mode=%s", _MODE.value)
server_runner.run(
test_port=xds_flags.SERVER_PORT.value,
maintenance_port=xds_flags.SERVER_MAINTENANCE_PORT.value,
secure_mode=_MODE.value == "secure",
log_to_stdout=_FOLLOW.value,
)
if should_follow_logs:
print("Following pod logs. Press Ctrl+C top stop")
signal.signal(signal.SIGINT, _make_sigint_handler(server_runner))
signal.pause()
elif _CMD.value == "cleanup":
logger.info("Cleanup server")
server_runner.cleanup(
force=True, force_namespace=_CLEANUP_NAMESPACE.value
)
if __name__ == "__main__":
app.run(main)

@ -1,3 +0,0 @@
# Common config file for PSM CSM tests.
--resource_prefix=psm-csm
--noenable_workload_identity

@ -1,11 +0,0 @@
--resource_prefix=psm-interop
--td_bootstrap_image=gcr.io/grpc-testing/td-grpc-bootstrap:7d8d90477792e2e1bfe3a3da20b3dc9ef01d326c
# The canonical implementation of the xDS test server.
# Can be used in tests where language-specific xDS test server does not exist,
# or missing a feature required for the test.
# TODO(sergiitk): Update every ~ 6 months; next 2024-01.
--server_image_canonical=gcr.io/grpc-testing/xds-interop/java-server:canonical-v1.56
--logger_levels=__main__:DEBUG,framework:INFO
--verbosity=0

@ -1,4 +0,0 @@
# Common config file for GAMMA PSM tests.
# TODO(sergiitk): delete when confirmed it's not used
--resource_prefix=psm-gamma
--noenable_workload_identity

@ -1,9 +0,0 @@
--flagfile=config/common.cfg
--project=grpc-testing
--network=default-vpc
--gcp_service_account=xds-k8s-interop-tests@grpc-testing.iam.gserviceaccount.com
--private_api_key_secret_name=projects/830293263384/secrets/xds-interop-tests-private-api-access-key
# Randomize xds port.
--server_xds_port=0
# ResultStore UI doesn't support 256 colors.
--color_style=ansi16

@ -1,62 +0,0 @@
# Copy to local-dev.cfg; replace ${UPPERCASED_VARS}. Details in README.md.
## Import common settings
--flagfile=config/common.cfg
### --------------------------------- Project ----------------------------------
## Project settings
--project=${PROJECT_ID}
--gcp_service_account=${WORKLOAD_SA_EMAIL}
--private_api_key_secret_name=projects/${PROJECT_NUMBER}/secrets/xds-interop-tests-private-api-access-key
### --------------------------------- Clusters ---------------------------------
## The name of kube context to use (points to your GKE cluster).
--kube_context=${KUBE_CONTEXT}
### ------------------------------- App images ---------------------------------
## Test images, f.e. java v1.57.x.
--server_image=gcr.io/grpc-testing/xds-interop/java-server:v1.57.x
--client_image=gcr.io/grpc-testing/xds-interop/java-client:v1.57.x
### ----------------------------------- App ------------------------------------
## Use a resource prefix to describe usage and ownership.
--resource_prefix=${USER}-psm
## Use random port in the server xds address, f.e. xds://my-test-server:42
--server_xds_port=0
## When running ./bin helpers, you might need to set randomly generated fields
## to a static value.
# --resource_suffix=dev
# --server_xds_port=1111
### --------------------------------- Logging ----------------------------------
## Verbosity: -3 (fatal/critical), -2 (error), -1 (warning), 0 (info), 1 (debug)
# --verbosity=1
## Uncomment and set different log levels per module. Examples:
# --logger_levels=__main__:DEBUG,framework:INFO
# --logger_levels=__main__:INFO,framework:DEBUG,urllib3.connectionpool:ERROR
## Uncomment to collect test client, server logs to out/test_app_logs/ folder.
# --collect_app_logs
# --log_dir=out
### ------------------------------- Local dev ---------------------------------
## Enable port forwarding in local dev.
--debug_use_port_forwarding
## (convenience) Don't fail when these flags aren't defined.
--undefok=private_api_key_secret_name,gcp_ui_url
## Uncomment to create the firewall rule before test case runs.
# --ensure_firewall
## Uncomment if the health check port opened in the firewall is different from 8080.
# --server_port=50051

@ -1,15 +0,0 @@
--resource_prefix=interop-psm-url-map
--strategy=reuse
--server_xds_port=8848
# NOTE(lidiz) we pin the server image to java-server because:
# 1. Only Java server understands the rpc-behavior metadata.
# 2. All UrlMap tests today are testing client-side logic.
#
# TODO(sergiitk): Use --server_image_canonical instead.
--server_image=gcr.io/grpc-testing/xds-interop/java-server:canonical-v1.56
# Disables the GCP Workload Identity feature to simplify permission control
--gcp_service_account=None
--private_api_key_secret_name=None
--noenable_workload_identity

@ -1,13 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,182 +0,0 @@
# Copyright 2022 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from framework import xds_k8s_testcase
from framework.helpers import rand as helpers_rand
from framework.infrastructure import k8s
from framework.infrastructure import traffic_director
from framework.test_app.runners.k8s import k8s_xds_client_runner
from framework.test_app.runners.k8s import k8s_xds_server_runner
logger = logging.getLogger(__name__)
# Type aliases
TrafficDirectorManager = traffic_director.TrafficDirectorManager
XdsTestServer = xds_k8s_testcase.XdsTestServer
XdsTestClient = xds_k8s_testcase.XdsTestClient
KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner
KubernetesClientRunner = k8s_xds_client_runner.KubernetesClientRunner
class BootstrapGeneratorBaseTest(xds_k8s_testcase.XdsKubernetesBaseTestCase):
"""Common functionality to support testing of bootstrap generator versions
across gRPC clients and servers."""
@classmethod
def setUpClass(cls):
"""Hook method for setting up class fixture before running tests in
the class.
"""
super().setUpClass()
if cls.server_maintenance_port is None:
cls.server_maintenance_port = (
KubernetesServerRunner.DEFAULT_MAINTENANCE_PORT
)
# Bootstrap generator tests are run as parameterized tests which only
# perform steps specific to the parameterized version of the bootstrap
# generator under test.
#
# Here, we perform setup steps which are common across client and server
# side variants of the bootstrap generator test.
if cls.resource_suffix_randomize:
cls.resource_suffix = helpers_rand.random_resource_suffix()
logger.info(
"Test run resource prefix: %s, suffix: %s",
cls.resource_prefix,
cls.resource_suffix,
)
# TD Manager
cls.td = cls.initTrafficDirectorManager()
# Test namespaces for client and server.
cls.server_namespace = KubernetesServerRunner.make_namespace_name(
cls.resource_prefix, cls.resource_suffix
)
cls.client_namespace = KubernetesClientRunner.make_namespace_name(
cls.resource_prefix, cls.resource_suffix
)
        # Ensure the firewall rule exists.
if cls.ensure_firewall:
cls.td.create_firewall_rule(
allowed_ports=cls.firewall_allowed_ports
)
        # Randomize the xds port when it's set to 0.
if cls.server_xds_port == 0:
# TODO(sergiitk): this is prone to race conditions:
            # The port might not be taken now, but there's no guarantee
            # it won't be taken by the time the tests get to creating the
            # forwarding rule. This check is better than nothing,
# but we should find a better approach.
cls.server_xds_port = cls.td.find_unused_forwarding_rule_port()
logger.info("Found unused xds port: %s", cls.server_xds_port)
# Common TD resources across client and server tests.
cls.td.setup_for_grpc(
cls.server_xds_host,
cls.server_xds_port,
health_check_port=cls.server_maintenance_port,
)
@classmethod
def tearDownClass(cls):
cls.td.cleanup(force=cls.force_cleanup)
super().tearDownClass()
@classmethod
def initTrafficDirectorManager(cls) -> TrafficDirectorManager:
return TrafficDirectorManager(
cls.gcp_api_manager,
project=cls.project,
resource_prefix=cls.resource_prefix,
resource_suffix=cls.resource_suffix,
network=cls.network,
compute_api_version=cls.compute_api_version,
)
@classmethod
def initKubernetesServerRunner(
cls, *, td_bootstrap_image: Optional[str] = None
) -> KubernetesServerRunner:
if not td_bootstrap_image:
td_bootstrap_image = cls.td_bootstrap_image
return KubernetesServerRunner(
k8s.KubernetesNamespace(cls.k8s_api_manager, cls.server_namespace),
deployment_name=cls.server_name,
image_name=cls.server_image,
td_bootstrap_image=td_bootstrap_image,
gcp_project=cls.project,
gcp_api_manager=cls.gcp_api_manager,
gcp_service_account=cls.gcp_service_account,
xds_server_uri=cls.xds_server_uri,
network=cls.network,
debug_use_port_forwarding=cls.debug_use_port_forwarding,
enable_workload_identity=cls.enable_workload_identity,
)
@staticmethod
def startTestServer(
server_runner,
port,
maintenance_port,
xds_host,
xds_port,
replica_count=1,
**kwargs,
) -> XdsTestServer:
test_server = server_runner.run(
replica_count=replica_count,
test_port=port,
maintenance_port=maintenance_port,
**kwargs,
)[0]
test_server.set_xds_address(xds_host, xds_port)
return test_server
def initKubernetesClientRunner(
self, td_bootstrap_image: Optional[str] = None
) -> KubernetesClientRunner:
if not td_bootstrap_image:
td_bootstrap_image = self.td_bootstrap_image
return KubernetesClientRunner(
k8s.KubernetesNamespace(
self.k8s_api_manager, self.client_namespace
),
deployment_name=self.client_name,
image_name=self.client_image,
td_bootstrap_image=td_bootstrap_image,
gcp_project=self.project,
gcp_api_manager=self.gcp_api_manager,
gcp_service_account=self.gcp_service_account,
xds_server_uri=self.xds_server_uri,
network=self.network,
debug_use_port_forwarding=self.debug_use_port_forwarding,
enable_workload_identity=self.enable_workload_identity,
stats_port=self.client_port,
reuse_namespace=self.server_namespace == self.client_namespace,
)
def startTestClient(
self, test_server: XdsTestServer, **kwargs
) -> XdsTestClient:
test_client = self.client_runner.run(
server_target=test_server.xds_uri, **kwargs
)
test_client.wait_for_server_channel_ready()
return test_client

@ -1,58 +0,0 @@
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
# TODO(sergiitk): All custom error classes should extend this.
class FrameworkError(Exception):
"""Base error class for framework errors."""
message: str
kwargs: dict[str, Any]
note: str = ""
def __init__(self, message: str, *args, **kwargs):
self.message = message
# Exception only stores args.
self.kwargs = kwargs
# Pass to the Exception as if message is in **args.
super().__init__(*[message, *args])
# TODO(sergiitk): Remove in py3.11, this will be built-in. See PEP 678.
def add_note(self, note: str):
self.note = note
def __str__(self):
return self.message if not self.note else f"{self.message}\n{self.note}"
@classmethod
def note_blanket_error(cls, reason: str) -> str:
return f"""
Reason: {reason}
{'#' * 80}
# IMPORTANT: This is not a root cause. This is an indication that
# _something_ -- literally _anything_ -- has gone wrong in the xDS flow.
# It is _your_ responsibility to look through the interop client and/or
# server logs to determine what exactly went wrong.
{'#' * 80}
"""
@classmethod
def note_blanket_error_info_below(
cls, reason: str, *, info_below: str
) -> str:
return (
f"{cls.note_blanket_error(reason)}"
f"# Please inspect the information below:\n{info_below}"
)
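# A minimal usage sketch (hypothetical message and kwargs, for illustration):
#
#   try:
#       raise FrameworkError("RPCs failed", num_failures=13)
#   except FrameworkError as error:
#       error.add_note(FrameworkError.note_blanket_error("xDS flow broken"))
#       print(error)  # The message first, then the note on the lines below.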

@ -1,13 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,79 +0,0 @@
# Copyright 2021 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This contains common helpers for working with dates and time."""
import datetime
import re
from typing import Optional, Pattern
import dateutil.parser
RE_ZERO_OFFSET: Pattern[str] = re.compile(r"[+\-]00:?00$")
def utc_now() -> datetime.datetime:
"""Construct a datetime from current time in UTC timezone."""
return datetime.datetime.now(datetime.timezone.utc)
def shorten_utc_zone(utc_datetime_str: str) -> str:
"""Replace ±00:00 timezone designator with Z (zero offset AKA Zulu time)."""
return RE_ZERO_OFFSET.sub("Z", utc_datetime_str)
def iso8601_utc_time(time: Optional[datetime.datetime] = None) -> str:
    """Converts datetime to UTC and formats it as ISO-8601 Zulu time."""
    # Assumption: a missing argument means "now".
    if time is None:
        time = utc_now()
    utc_time = time.astimezone(tz=datetime.timezone.utc)
return shorten_utc_zone(utc_time.isoformat())
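# F.e. (hypothetical timestamps):
#   shorten_utc_zone("2023-06-01T18:59:42+00:00") -> "2023-06-01T18:59:42Z"
#   iso8601_utc_time() may return f.e. "2023-06-01T18:59:42.123456Z"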
def iso8601_to_datetime(date_str: str) -> datetime.datetime:
# TODO(sergiitk): use regular datetime.datetime when upgraded to py3.11.
return dateutil.parser.isoparse(date_str)
def datetime_suffix(*, seconds: bool = False) -> str:
"""Return current UTC date, and time in a format useful for resource naming.
Examples:
- 20210626-1859 (seconds=False)
- 20210626-185942 (seconds=True)
    Use in resource names incompatible with ISO 8601, e.g. some GCP resources
that only allow lowercase alphanumeric chars and dashes.
Hours and minutes are joined together for better readability, so time is
visually distinct from dash-separated date.
"""
return utc_now().strftime("%Y%m%d-%H%M" + ("%S" if seconds else ""))
def ago(date_from: datetime.datetime, now: Optional[datetime.datetime] = None):
if not now:
now = utc_now()
# Round down microseconds.
date_from = date_from.replace(microsecond=0)
now = now.replace(microsecond=0)
# Calculate the diff.
delta: datetime.timedelta = now - date_from
if delta.days > 1:
result = f"{delta.days} days"
elif delta.days > 0:
result = f"{delta.days} day"
else:
# This case covers negative deltas too.
result = f"{delta} (h:mm:ss)"
return f"{result} ago"

@ -1,204 +0,0 @@
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This contains common helpers for working with grpc data structures."""
import dataclasses
import functools
from typing import Dict, List, Optional
import grpc
import yaml
from framework.rpc import grpc_testing
# Type aliases
RpcsByPeer = Dict[str, int]
RpcMetadata = grpc_testing.LoadBalancerStatsResponse.RpcMetadata
# At runtime, the per-peer metadata is flattened into a single pretty-printed
# string; see PrettyLoadBalancerStats._parse_metadatas_by_peer below.
MetadataByPeer = str
@functools.cache # pylint: disable=no-member
def status_from_int(grpc_status_int: int) -> Optional[grpc.StatusCode]:
"""Converts the integer gRPC status code to the grpc.StatusCode enum."""
for grpc_status in grpc.StatusCode:
if grpc_status.value[0] == grpc_status_int:
return grpc_status
return None
def status_eq(grpc_status_int: int, grpc_status: grpc.StatusCode) -> bool:
"""Compares the integer gRPC status code with the grpc.StatusCode enum."""
return status_from_int(grpc_status_int) is grpc_status
def status_pretty(grpc_status: grpc.StatusCode) -> str:
"""Formats the status code as (int, NAME), f.e. (4, DEADLINE_EXCEEDED)"""
return f"({grpc_status.value[0]}, {grpc_status.name})"
@dataclasses.dataclass(frozen=True)
class PrettyStatsPerMethod:
# The name of the method.
method: str
# The number of RPCs started for this method, completed and in-flight.
rpcs_started: int
# The number of RPCs that completed with each status for this method.
# Format: status code -> RPC count, f.e.:
# {
# "(0, OK)": 20,
# "(14, UNAVAILABLE)": 10
# }
result: Dict[str, int]
@functools.cached_property # pylint: disable=no-member
def rpcs_completed(self):
"""Returns the total count of competed RPCs across all statuses."""
return sum(self.result.values())
@staticmethod
def from_response(
method_name: str, method_stats: grpc_testing.MethodStats
) -> "PrettyStatsPerMethod":
stats: Dict[str, int] = dict()
for status_int, count in method_stats.result.items():
status: Optional[grpc.StatusCode] = status_from_int(status_int)
status_formatted = status_pretty(status) if status else "None"
stats[status_formatted] = count
return PrettyStatsPerMethod(
method=method_name,
rpcs_started=method_stats.rpcs_started,
result=stats,
)
def accumulated_stats_pretty(
accumulated_stats: grpc_testing.LoadBalancerAccumulatedStatsResponse,
*,
ignore_empty: bool = False,
) -> str:
"""Pretty print LoadBalancerAccumulatedStatsResponse.
Example:
- method: EMPTY_CALL
rpcs_started: 0
result:
(2, UNKNOWN): 20
- method: UNARY_CALL
rpcs_started: 31
result:
(0, OK): 10
(14, UNAVAILABLE): 20
"""
# Only look at stats_per_method, as the other fields are deprecated.
result: List[Dict] = []
for method_name, method_stats in accumulated_stats.stats_per_method.items():
pretty_stats = PrettyStatsPerMethod.from_response(
method_name, method_stats
)
# Skip methods with no RPCs reported when ignore_empty is True.
if ignore_empty and not pretty_stats.rpcs_started:
continue
result.append(dataclasses.asdict(pretty_stats))
return yaml.dump(result, sort_keys=False)
@dataclasses.dataclass(frozen=True)
class PrettyLoadBalancerStats:
# The number of RPCs that failed to record a remote peer.
num_failures: int
# The number of completed RPCs for each peer.
# Format: a dictionary from the host name (str) to the RPC count (int), f.e.
# {"host-a": 10, "host-b": 20}
rpcs_by_peer: "RpcsByPeer"
    # The number of completed RPCs per method for each peer.
# Format: a dictionary from the method name to RpcsByPeer (see above), f.e.:
# {
# "UNARY_CALL": {"host-a": 10, "host-b": 20},
# "EMPTY_CALL": {"host-a": 42},
# }
rpcs_by_method: Dict[str, "RpcsByPeer"]
metadatas_by_peer: Dict[str, "MetadataByPeer"]
@staticmethod
def _parse_rpcs_by_peer(
rpcs_by_peer: grpc_testing.RpcsByPeer,
) -> "RpcsByPeer":
result = dict()
for peer, count in rpcs_by_peer.items():
result[peer] = count
return result
@staticmethod
def _parse_metadatas_by_peer(
metadatas_by_peer: grpc_testing.LoadBalancerStatsResponse.MetadataByPeer,
) -> "MetadataByPeer":
result = dict()
for peer, metadatas in metadatas_by_peer.items():
pretty_metadata = ""
for rpc_metadatas in metadatas.rpc_metadata:
for metadata in rpc_metadatas.metadata:
pretty_metadata += (
metadata.key + ": " + metadata.value + ", "
)
result[peer] = pretty_metadata
return result
@classmethod
def from_response(
cls, lb_stats: grpc_testing.LoadBalancerStatsResponse
) -> "PrettyLoadBalancerStats":
rpcs_by_method: Dict[str, "RpcsByPeer"] = dict()
for method_name, stats in lb_stats.rpcs_by_method.items():
if stats:
rpcs_by_method[method_name] = cls._parse_rpcs_by_peer(
stats.rpcs_by_peer
)
return PrettyLoadBalancerStats(
num_failures=lb_stats.num_failures,
rpcs_by_peer=cls._parse_rpcs_by_peer(lb_stats.rpcs_by_peer),
rpcs_by_method=rpcs_by_method,
metadatas_by_peer=cls._parse_metadatas_by_peer(
lb_stats.metadatas_by_peer
),
)
def lb_stats_pretty(lb: grpc_testing.LoadBalancerStatsResponse) -> str:
"""Pretty print LoadBalancerStatsResponse.
Example:
num_failures: 13
rpcs_by_method:
UNARY_CALL:
psm-grpc-server-a: 100
psm-grpc-server-b: 42
EMPTY_CALL:
psm-grpc-server-a: 200
rpcs_by_peer:
psm-grpc-server-a: 200
psm-grpc-server-b: 42
"""
pretty_lb_stats = PrettyLoadBalancerStats.from_response(lb)
stats_as_dict = dataclasses.asdict(pretty_lb_stats)
# Don't print metadatas_by_peer unless it has data
if not stats_as_dict["metadatas_by_peer"]:
stats_as_dict.pop("metadatas_by_peer")
return yaml.dump(stats_as_dict, sort_keys=False)

@@ -1,106 +0,0 @@
# Copyright 2021 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The module contains helpers to enable color output in terminals.
Use this to log resources dumped as a structured document (f.e. YAML),
and enable colorful syntax highlighting.
TODO(sergiitk): This can be used to output protobuf responses formatted as JSON.
"""
import logging
from typing import Optional
from absl import flags
import pygments
import pygments.formatter
import pygments.formatters.other
import pygments.formatters.terminal
import pygments.formatters.terminal256
import pygments.lexer
import pygments.lexers.data
import pygments.styles
# The style for terminals supporting 8/16 colors.
STYLE_ANSI_16 = "ansi16"
# Join with pygments styles for terminals supporting 88/256 colors.
ALL_COLOR_STYLES = [STYLE_ANSI_16] + list(pygments.styles.get_all_styles())
# Flags.
COLOR = flags.DEFINE_bool("color", default=True, help="Colorize the output")
COLOR_STYLE = flags.DEFINE_enum(
"color_style",
default="material",
enum_values=ALL_COLOR_STYLES,
help=(
"Color styles for terminals supporting 256 colors. "
f"Use {STYLE_ANSI_16} style for terminals supporting 8/16 colors"
),
)
logger = logging.getLogger(__name__)
# Type aliases.
Lexer = pygments.lexer.Lexer
YamlLexer = pygments.lexers.data.YamlLexer
Formatter = pygments.formatter.Formatter
NullFormatter = pygments.formatters.other.NullFormatter
TerminalFormatter = pygments.formatters.terminal.TerminalFormatter
Terminal256Formatter = pygments.formatters.terminal256.Terminal256Formatter
class Highlighter:
formatter: Formatter
lexer: Lexer
color: bool
color_style: Optional[str] = None
def __init__(
self,
*,
lexer: Lexer,
color: Optional[bool] = None,
color_style: Optional[str] = None,
):
self.lexer = lexer
self.color = color if color is not None else COLOR.value
if self.color:
color_style = color_style if color_style else COLOR_STYLE.value
if color_style not in ALL_COLOR_STYLES:
raise ValueError(
f"Unrecognized color style {color_style}, "
f"valid styles: {ALL_COLOR_STYLES}"
)
if color_style == STYLE_ANSI_16:
# 8/16 colors support only.
self.formatter = TerminalFormatter()
else:
# 88/256 colors.
self.formatter = Terminal256Formatter(style=color_style)
else:
self.formatter = NullFormatter()
def highlight(self, code: str) -> str:
return pygments.highlight(code, self.lexer, self.formatter)
class HighlighterYaml(Highlighter):
def __init__(
self, *, color: Optional[bool] = None, color_style: Optional[str] = None
):
super().__init__(
lexer=YamlLexer(encoding="utf-8"),
color=color,
color_style=color_style,
)
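if __name__ == "__main__":  # pragma: no cover
    # Editor's usage sketch (not part of the original file). Explicit color
    # arguments avoid touching unparsed absl flags; the YAML doc is made up.
    highlighter = HighlighterYaml(color=True, color_style=STYLE_ANSI_16)
    print(highlighter.highlight("name: psm-grpc-server\nreplicas: 3\n"))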

@@ -1,48 +0,0 @@
# Copyright 2022 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The module contains helpers to initialize and configure logging."""
import functools
import pathlib
from absl import flags
from absl import logging
def _ensure_flags_parsed() -> None:
if not flags.FLAGS.is_parsed():
raise flags.UnparsedFlagAccessError("Must initialize absl flags first.")
@functools.lru_cache(None)
def log_get_root_dir() -> pathlib.Path:
_ensure_flags_parsed()
log_root = pathlib.Path(logging.find_log_dir()).absolute()
logging.info("Log root dir: %s", log_root)
return log_root
def log_dir_mkdir(name: str) -> pathlib.Path:
"""Creates and returns a subdir with the given name in the log folder."""
if len(pathlib.Path(name).parts) != 1:
raise ValueError(f"Dir name must be a single component; got: {name}")
if ".." in name:
raise ValueError(f"Dir name must not be above the log root.")
log_subdir = log_get_root_dir() / name
if log_subdir.exists() and log_subdir.is_dir():
logging.debug("Using existing log subdir: %s", log_subdir)
else:
log_subdir.mkdir()
logging.debug("Created log subdir: %s", log_subdir)
return log_subdir
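if __name__ == "__main__":  # pragma: no cover
    # Editor's usage sketch (not part of the original file). absl must parse
    # flags before the log dir helpers are used; the subdir name is made up.
    from absl import app

    def _demo(argv):
        del argv  # Unused.
        logging.info("Created log subdir: %s", log_dir_mkdir("demo-logs"))

    app.run(_demo)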

@@ -1,49 +0,0 @@
# Copyright 2021 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This contains common helpers for generating randomized data."""
import random
import string
import framework.helpers.datetime
# Alphanumeric characters, similar to regex [:alnum:] class, [a-zA-Z0-9]
ALPHANUM = string.ascii_letters + string.digits
# Lowercase alphanumeric characters: [a-z0-9]
# Use ALPHANUM_LOWERCASE alphabet when case-sensitivity is a concern.
ALPHANUM_LOWERCASE = string.ascii_lowercase + string.digits
def rand_string(length: int = 8, *, lowercase: bool = False) -> str:
"""Return random alphanumeric string of given length.
Space for default arguments: alphabet^length
lowercase and uppercase = (26*2 + 10)^8 = 2.18e14 = 218 trillion.
lowercase only = (26 + 10)^8 = 2.8e12 = 2.8 trillion.
"""
alphabet = ALPHANUM_LOWERCASE if lowercase else ALPHANUM
return "".join(random.choices(population=alphabet, k=length))
def random_resource_suffix() -> str:
"""Return a ready-to-use resource suffix with datetime and nonce."""
    # Date and time suffix for debugging. Seconds are skipped as not relevant.
# Format example: 20210626-1859
datetime_suffix: str = framework.helpers.datetime.datetime_suffix()
# Use lowercase chars because some resource names won't allow uppercase.
# For len 5, total (26 + 10)^5 = 60,466,176 combinations.
# Approx. number of test runs needed to start at the same minute to
# produce a collision: math.sqrt(math.pi/2 * (26+10)**5) ≈ 9745.
# https://en.wikipedia.org/wiki/Birthday_attack#Mathematics
unique_hash: str = rand_string(5, lowercase=True)
return f"{datetime_suffix}-{unique_hash}"

@@ -1,273 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This contains common retrying helpers (retryers).
We use tenacity as a general-purpose retrying library.
> It [tenacity] originates from a fork of retrying which is sadly no
> longer maintained. Tenacity isn't api compatible with retrying but
> adds significant new functionality and fixes a number of longstanding bugs.
> - https://tenacity.readthedocs.io/en/latest/index.html
"""
import datetime
import logging
from typing import Any, Callable, List, Optional, Tuple, Type
import tenacity
from tenacity import _utils as tenacity_utils
from tenacity import compat as tenacity_compat
from tenacity import stop
from tenacity import wait
from tenacity.retry import retry_base
retryers_logger = logging.getLogger(__name__)
# Type aliases
timedelta = datetime.timedelta
Retrying = tenacity.Retrying
CheckResultFn = Callable[[Any], bool]
_ExceptionClasses = Tuple[Type[Exception], ...]
def _build_retry_conditions(
*,
retry_on_exceptions: Optional[_ExceptionClasses] = None,
check_result: Optional[CheckResultFn] = None,
) -> List[retry_base]:
# Retry on all exceptions by default
if retry_on_exceptions is None:
retry_on_exceptions = (Exception,)
retry_conditions = [tenacity.retry_if_exception_type(retry_on_exceptions)]
if check_result is not None:
if retry_on_exceptions:
# When retry_on_exceptions is set, also catch them while executing
# check_result callback.
check_result = _safe_check_result(check_result, retry_on_exceptions)
retry_conditions.append(tenacity.retry_if_not_result(check_result))
return retry_conditions
def exponential_retryer_with_timeout(
*,
wait_min: timedelta,
wait_max: timedelta,
timeout: timedelta,
retry_on_exceptions: Optional[_ExceptionClasses] = None,
check_result: Optional[CheckResultFn] = None,
logger: Optional[logging.Logger] = None,
log_level: Optional[int] = logging.DEBUG,
) -> Retrying:
if logger is None:
logger = retryers_logger
if log_level is None:
log_level = logging.DEBUG
retry_conditions = _build_retry_conditions(
retry_on_exceptions=retry_on_exceptions, check_result=check_result
)
retry_error_callback = _on_error_callback(
timeout=timeout, check_result=check_result
)
return Retrying(
retry=tenacity.retry_any(*retry_conditions),
wait=wait.wait_exponential(
min=wait_min.total_seconds(), max=wait_max.total_seconds()
),
stop=stop.stop_after_delay(timeout.total_seconds()),
before_sleep=_before_sleep_log(logger, log_level),
retry_error_callback=retry_error_callback,
)
def constant_retryer(
*,
wait_fixed: timedelta,
attempts: int = 0,
timeout: Optional[timedelta] = None,
retry_on_exceptions: Optional[_ExceptionClasses] = None,
check_result: Optional[CheckResultFn] = None,
logger: Optional[logging.Logger] = None,
log_level: Optional[int] = logging.DEBUG,
) -> Retrying:
if logger is None:
logger = retryers_logger
if log_level is None:
log_level = logging.DEBUG
if attempts < 1 and timeout is None:
raise ValueError("The number of attempts or the timeout must be set")
stops = []
if attempts > 0:
stops.append(stop.stop_after_attempt(attempts))
if timeout is not None:
stops.append(stop.stop_after_delay(timeout.total_seconds()))
retry_conditions = _build_retry_conditions(
retry_on_exceptions=retry_on_exceptions, check_result=check_result
)
retry_error_callback = _on_error_callback(
timeout=timeout, attempts=attempts, check_result=check_result
)
return Retrying(
retry=tenacity.retry_any(*retry_conditions),
wait=wait.wait_fixed(wait_fixed.total_seconds()),
stop=stop.stop_any(*stops),
before_sleep=_before_sleep_log(logger, log_level),
retry_error_callback=retry_error_callback,
)
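# --- Editor's usage sketch (not part of the original file). Polls a
# hypothetical load_status() callable every 5 seconds for up to a minute,
# retrying until the result check passes:
#
#   retryer = constant_retryer(
#       wait_fixed=datetime.timedelta(seconds=5),
#       timeout=datetime.timedelta(minutes=1),
#       check_result=lambda status: status.ready,
#   )
#   status = retryer(load_status)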
def _on_error_callback(
*,
timeout: Optional[timedelta] = None,
attempts: int = 0,
check_result: Optional[CheckResultFn] = None,
):
"""A helper to propagate the initial state to the RetryError, so that
it can assemble a helpful message containing timeout/number of attempts.
"""
def error_handler(retry_state: tenacity.RetryCallState):
raise RetryError(
retry_state,
timeout=timeout,
attempts=attempts,
check_result=check_result,
)
return error_handler
def _safe_check_result(
check_result: CheckResultFn, retry_on_exceptions: _ExceptionClasses
) -> CheckResultFn:
"""Wraps check_result callback to catch and handle retry_on_exceptions.
Normally tenacity doesn't retry when retry_if_result/retry_if_not_result
raise an error. This wraps the callback to automatically catch Exceptions
specified in the retry_on_exceptions argument.
    Ideally we should make all check_result callbacks not throw, but in
    case one does, we'd rather be noisy in the logs than break the test.
"""
def _check_result_wrapped(result):
try:
return check_result(result)
except retry_on_exceptions:
retryers_logger.warning(
(
"Result check callback %s raised an exception."
"This shouldn't happen, please handle any exceptions and "
"return return a boolean."
),
tenacity_utils.get_callback_name(check_result),
exc_info=True,
)
return False
return _check_result_wrapped
def _before_sleep_log(logger, log_level, exc_info=False):
"""Same as tenacity.before_sleep_log, but only logs primitive return values.
    Logging the full return value is not useful when it is a dump of
    a large object.
"""
def log_it(retry_state):
if retry_state.outcome.failed:
ex = retry_state.outcome.exception()
verb, value = "raised", "%s: %s" % (type(ex).__name__, ex)
if exc_info:
local_exc_info = tenacity_compat.get_exc_info_from_future(
retry_state.outcome
)
else:
local_exc_info = False
else:
local_exc_info = False # exc_info does not apply when no exception
result = retry_state.outcome.result()
if isinstance(result, (int, bool, str)):
verb, value = "returned", result
else:
verb, value = "returned type", type(result)
logger.log(
log_level,
"Retrying %s in %s seconds as it %s %s.",
tenacity_utils.get_callback_name(retry_state.fn),
getattr(retry_state.next_action, "sleep"),
verb,
value,
exc_info=local_exc_info,
)
return log_it
class RetryError(tenacity.RetryError):
# Note: framework.errors.FrameworkError could be used as a mixin,
# but this would rely too much on tenacity.RetryError to not change.
last_attempt: tenacity.Future
note: str = ""
def __init__(
self,
retry_state,
*,
timeout: Optional[timedelta] = None,
attempts: int = 0,
check_result: Optional[CheckResultFn] = None,
):
last_attempt: tenacity.Future = retry_state.outcome
super().__init__(last_attempt)
callback_name = tenacity_utils.get_callback_name(retry_state.fn)
self.message = f"Retry error calling {callback_name}:"
if timeout:
self.message += f" timeout {timeout} (h:mm:ss) exceeded"
if attempts:
self.message += " or"
if attempts:
self.message += f" {attempts} attempts exhausted"
self.message += "."
if last_attempt.failed:
err = last_attempt.exception()
self.message += f" Last exception: {type(err).__name__}: {err}"
elif check_result:
self.message += " Check result callback returned False."
def result(self, *, default=None):
return (
self.last_attempt.result()
if not self.last_attempt.failed
else default
)
def exception(self, *, default=None):
return (
self.last_attempt.exception()
if self.last_attempt.failed
else default
)
# TODO(sergiitk): Remove in py3.11, this will be built-in. See PEP 678.
def add_note(self, note: str):
self.note = note
def __str__(self):
return self.message if not self.note else f"{self.message}\n{self.note}"

@@ -1,103 +0,0 @@
# Copyright 2022 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The classes and predicates to assist validate test config for test cases."""
from dataclasses import dataclass
import enum
import logging
import re
from typing import Optional
from packaging import version as pkg_version
logger = logging.getLogger(__name__)
class Lang(enum.Flag):
UNKNOWN = enum.auto()
CPP = enum.auto()
GO = enum.auto()
JAVA = enum.auto()
PYTHON = enum.auto()
NODE = enum.auto()
def __str__(self):
return str(self.name).lower()
@classmethod
def from_string(cls, lang: str):
try:
return cls[lang.upper()]
except KeyError:
return cls.UNKNOWN
@dataclass
class TestConfig:
"""Describes the config for the test suite.
TODO(sergiitk): rename to LangSpec and rename skips.py to lang.py.
"""
client_lang: Lang
server_lang: Lang
version: Optional[str]
def version_gte(self, another: str) -> bool:
"""Returns a bool for whether this VERSION is >= then ANOTHER version.
Special cases:
1) Versions "master" or "dev" are always greater than ANOTHER:
- master > v1.999.x > v1.55.x
- dev > v1.999.x > v1.55.x
- dev == master
2) Versions "dev-VERSION" behave the same as the VERSION:
- dev-master > v1.999.x > v1.55.x
- dev-master == dev == master
- v1.55.x > dev-v1.54.x > v1.53.x
- dev-v1.54.x == v1.54.x
3) Unspecified version (self.version is None) is treated as "master".
"""
if self.version in ("master", "dev", "dev-master", None):
return True
# The left side is not master, so master on the right side wins.
if another == "master":
return False
# Treat "dev-VERSION" on the left side as "VERSION".
version: str = self.version
if version.startswith("dev-"):
version = version[4:]
return self._parse_version(version) >= self._parse_version(another)
def __str__(self):
return (
f"TestConfig(client_lang='{self.client_lang}', "
f"server_lang='{self.server_lang}', version={self.version!r})"
)
@staticmethod
def _parse_version(version: str) -> pkg_version.Version:
if version.endswith(".x"):
version = version[:-2]
return pkg_version.Version(version)
def get_lang(image_name: str) -> Lang:
return Lang.from_string(
re.search(r"/(\w+)-(client|server):", image_name).group(1)
)
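if __name__ == "__main__":  # pragma: no cover
    # Editor's usage sketch (not part of the original file); the image name
    # and versions below are made up.
    config = TestConfig(
        client_lang=Lang.from_string("cpp"),
        server_lang=get_lang("gcr.io/example/java-server:v1.55.x"),
        version="dev-v1.54.x",
    )
    assert config.server_lang is Lang.JAVA
    assert config.version_gte("v1.54.x")  # dev-v1.54.x == v1.54.x.
    assert not config.version_gte("master")  # Only master/dev >= master.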

@@ -1,13 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -1,18 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from framework.infrastructure.gcp import api
from framework.infrastructure.gcp import compute
from framework.infrastructure.gcp import iam
from framework.infrastructure.gcp import network_security
from framework.infrastructure.gcp import network_services

@@ -1,542 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import contextlib
import functools
import json
import logging
from typing import Any, Dict, List, Optional
from absl import flags
from google.cloud import secretmanager_v1
from google.longrunning import operations_pb2
from google.protobuf import json_format
from google.rpc import code_pb2
from google.rpc import error_details_pb2
from google.rpc import status_pb2
from googleapiclient import discovery
import googleapiclient.errors
import googleapiclient.http
import tenacity
import yaml
import framework.helpers.highlighter
logger = logging.getLogger(__name__)
PRIVATE_API_KEY_SECRET_NAME = flags.DEFINE_string(
"private_api_key_secret_name",
default=None,
help=(
"Load Private API access key from the latest version of the secret "
"with the given name, in the format projects/*/secrets/*"
),
)
V1_DISCOVERY_URI = flags.DEFINE_string(
"v1_discovery_uri",
default=discovery.V1_DISCOVERY_URI,
help="Override v1 Discovery URI",
)
V2_DISCOVERY_URI = flags.DEFINE_string(
"v2_discovery_uri",
default=discovery.V2_DISCOVERY_URI,
help="Override v2 Discovery URI",
)
COMPUTE_V1_DISCOVERY_FILE = flags.DEFINE_string(
"compute_v1_discovery_file",
default=None,
help="Load compute v1 from discovery file",
)
GCP_UI_URL = flags.DEFINE_string(
"gcp_ui_url",
default="console.cloud.google.com",
help="Override GCP UI URL.",
)
# Type aliases
_HttpError = googleapiclient.errors.HttpError
_HttpLib2Error = googleapiclient.http.httplib2.HttpLib2Error
_HighlighterYaml = framework.helpers.highlighter.HighlighterYaml
Operation = operations_pb2.Operation
HttpRequest = googleapiclient.http.HttpRequest
class GcpApiManager:
def __init__(
self,
*,
v1_discovery_uri=None,
v2_discovery_uri=None,
compute_v1_discovery_file=None,
private_api_key_secret_name=None,
gcp_ui_url=None,
):
self.v1_discovery_uri = v1_discovery_uri or V1_DISCOVERY_URI.value
self.v2_discovery_uri = v2_discovery_uri or V2_DISCOVERY_URI.value
self.compute_v1_discovery_file = (
compute_v1_discovery_file or COMPUTE_V1_DISCOVERY_FILE.value
)
self.private_api_key_secret_name = (
private_api_key_secret_name or PRIVATE_API_KEY_SECRET_NAME.value
)
self.gcp_ui_url = gcp_ui_url or GCP_UI_URL.value
# TODO(sergiitk): add options to pass google Credentials
self._exit_stack = contextlib.ExitStack()
def close(self):
self._exit_stack.close()
@property
@functools.lru_cache(None)
def private_api_key(self):
"""
Private API key.
Return API key credential that identifies a GCP project allow-listed for
accessing private API discovery documents.
https://console.cloud.google.com/apis/credentials
This method lazy-loads the content of the key from the Secret Manager.
https://console.cloud.google.com/security/secret-manager
"""
if not self.private_api_key_secret_name:
raise ValueError(
"private_api_key_secret_name must be set to "
"access private_api_key."
)
secrets_api = self.secrets("v1")
version_resource_path = secrets_api.secret_version_path(
**secrets_api.parse_secret_path(self.private_api_key_secret_name),
secret_version="latest",
)
secret: secretmanager_v1.AccessSecretVersionResponse
secret = secrets_api.access_secret_version(name=version_resource_path)
return secret.payload.data.decode()
@functools.lru_cache(None)
def compute(self, version):
api_name = "compute"
if version == "v1":
if self.compute_v1_discovery_file:
return self._build_from_file(self.compute_v1_discovery_file)
else:
return self._build_from_discovery_v1(api_name, version)
elif version == "v1alpha":
return self._build_from_discovery_v1(api_name, "alpha")
raise NotImplementedError(f"Compute {version} not supported")
@functools.lru_cache(None)
def networksecurity(self, version):
api_name = "networksecurity"
if version == "v1alpha1":
return self._build_from_discovery_v2(
api_name,
version,
api_key=self.private_api_key,
visibility_labels=["NETWORKSECURITY_ALPHA"],
)
elif version == "v1beta1":
return self._build_from_discovery_v2(api_name, version)
raise NotImplementedError(f"Network Security {version} not supported")
@functools.lru_cache(None)
def networkservices(self, version):
api_name = "networkservices"
if version == "v1alpha1":
return self._build_from_discovery_v2(
api_name,
version,
api_key=self.private_api_key,
visibility_labels=["NETWORKSERVICES_ALPHA"],
)
elif version == "v1beta1":
return self._build_from_discovery_v2(api_name, version)
raise NotImplementedError(f"Network Services {version} not supported")
@staticmethod
@functools.lru_cache(None)
def secrets(version: str):
if version == "v1":
return secretmanager_v1.SecretManagerServiceClient()
raise NotImplementedError(f"Secret Manager {version} not supported")
@functools.lru_cache(None)
def iam(self, version: str) -> discovery.Resource:
"""Identity and Access Management (IAM) API.
https://cloud.google.com/iam/docs/reference/rest
https://googleapis.github.io/google-api-python-client/docs/dyn/iam_v1.html
"""
api_name = "iam"
if version == "v1":
return self._build_from_discovery_v1(api_name, version)
raise NotImplementedError(
f"Identity and Access Management (IAM) {version} not supported"
)
def _build_from_discovery_v1(self, api_name, version):
api = discovery.build(
api_name,
version,
cache_discovery=False,
discoveryServiceUrl=self.v1_discovery_uri,
)
self._exit_stack.enter_context(api)
return api
def _build_from_discovery_v2(
self,
api_name,
version,
*,
api_key: Optional[str] = None,
visibility_labels: Optional[List] = None,
):
params = {}
if api_key:
params["key"] = api_key
if visibility_labels:
            # Underscore-separated list of labels.
params["labels"] = "_".join(visibility_labels)
params_str = ""
if params:
params_str = "&" + "&".join(f"{k}={v}" for k, v in params.items())
api = discovery.build(
api_name,
version,
cache_discovery=False,
discoveryServiceUrl=f"{self.v2_discovery_uri}{params_str}",
)
self._exit_stack.enter_context(api)
return api
def _build_from_file(self, discovery_file):
with open(discovery_file, "r") as f:
api = discovery.build_from_document(f.read())
self._exit_stack.enter_context(api)
return api
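# --- Editor's usage sketch (not part of the original file). Once absl flags
# are parsed, clients are built lazily from the discovery documents and
# cached per API version:
#
#   api_manager = GcpApiManager()
#   compute_v1 = api_manager.compute("v1")
#   ...
#   api_manager.close()  # Closes every client built so far.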
class Error(Exception):
"""Base error class for GCP API errors."""
class ResponseError(Error):
"""The response was not a 2xx."""
reason: str
uri: str
error_details: Optional[str]
status: Optional[int]
cause: _HttpError
def __init__(self, cause: _HttpError):
# TODO(sergiitk): cleanup when we upgrade googleapiclient:
# - remove _get_reason()
# - remove error_details note
# - use status_code()
self.reason = cause._get_reason().strip() # noqa
self.uri = cause.uri
        self.error_details = cause.error_details  # NOTE: must be read after _get_reason().
self.status = None
if cause.resp and cause.resp.status:
self.status = cause.resp.status
self.cause = cause
super().__init__()
def __repr__(self):
return (
f"<ResponseError {self.status} when requesting {self.uri} "
f'returned "{self.reason}". Details: "{self.error_details}">'
)
class TransportError(Error):
"""A transport error has occurred."""
cause: _HttpLib2Error
def __init__(self, cause: _HttpLib2Error):
self.cause = cause
super().__init__()
def __repr__(self):
return f"<TransportError cause: {self.cause!r}>"
class OperationError(Error):
"""
Operation was not successful.
Assuming Operation based on Google API Style Guide:
https://cloud.google.com/apis/design/design_patterns#long_running_operations
https://github.com/googleapis/googleapis/blob/master/google/longrunning/operations.proto
"""
api_name: str
name: str
metadata: Any
code_name: code_pb2.Code
error: status_pb2.Status
def __init__(self, api_name: str, response: dict):
self.api_name = api_name
# Operation.metadata field is Any specific to the API. It may not be
# present in the default descriptor pool, and that's expected.
# To avoid json_format.ParseError, handle it separately.
self.metadata = response.pop("metadata", {})
# Must be after removing metadata field.
operation: Operation = self._parse_operation_response(response)
self.name = operation.name or "unknown"
self.code_name = code_pb2.Code.Name(operation.error.code)
self.error = operation.error
super().__init__()
@staticmethod
def _parse_operation_response(operation_response: dict) -> Operation:
try:
return json_format.ParseDict(
operation_response,
Operation(),
ignore_unknown_fields=True,
descriptor_pool=error_details_pb2.DESCRIPTOR.pool,
)
except (json_format.Error, TypeError) as e:
            # Swallow parsing errors, if any. Building a correct
            # OperationError() matters more than preserving every detail;
            # the response can still be extracted from the warning below.
logger.warning(
(
"Can't parse response while processing OperationError:"
" '%r', error %r"
),
operation_response,
e,
)
return Operation()
def __str__(self):
indent_l1 = " " * 2
indent_l2 = indent_l1 * 2
result = (
f'{self.api_name} operation "{self.name}" failed.\n'
f"{indent_l1}code: {self.error.code} ({self.code_name})\n"
f'{indent_l1}message: "{self.error.message}"'
)
if self.error.details:
result += f"\n{indent_l1}details: [\n"
for any_error in self.error.details:
error_str = json_format.MessageToJson(any_error)
for line in error_str.splitlines():
result += indent_l2 + line + "\n"
result += f"{indent_l1}]"
if self.metadata:
result += f"\n metadata: \n"
metadata_str = json.dumps(self.metadata, indent=2)
for line in metadata_str.splitlines():
result += indent_l2 + line + "\n"
result = result.rstrip()
return result
class GcpProjectApiResource:
# TODO(sergiitk): move someplace better
_WAIT_FOR_OPERATION_SEC = 60 * 10
_WAIT_FIXED_SEC = 2
_GCP_API_RETRIES = 5
def __init__(self, api: discovery.Resource, project: str):
self.api: discovery.Resource = api
self.project: str = project
self._highlighter = _HighlighterYaml()
# TODO(sergiitk): in upcoming GCP refactoring, differentiate between
# _execute for LRO (Long Running Operations), and immediate operations.
def _execute(
self,
request: HttpRequest,
*,
num_retries: Optional[int] = _GCP_API_RETRIES,
) -> Dict[str, Any]:
"""Execute the immediate request.
Returns:
Unmarshalled response as a dictionary.
Raises:
ResponseError if the response was not a 2xx.
TransportError if a transport error has occurred.
"""
if num_retries is None:
num_retries = self._GCP_API_RETRIES
try:
return request.execute(num_retries=num_retries)
except _HttpError as error:
raise ResponseError(error)
except _HttpLib2Error as error:
raise TransportError(error)
def resource_pretty_format(
self,
resource: Any,
*,
highlight: bool = True,
) -> str:
"""Return a string with pretty-printed resource body."""
yaml_out: str = yaml.dump(
resource,
explicit_start=True,
explicit_end=True,
)
return self._highlighter.highlight(yaml_out) if highlight else yaml_out
def resources_pretty_format(
self,
resources: list[Any],
*,
highlight: bool = True,
) -> str:
out = []
for resource in resources:
if hasattr(resource, "name"):
out.append(f"{resource.name}:")
elif "name" in resource:
out.append(f"{resource['name']}:")
out.append(
self.resource_pretty_format(resource, highlight=highlight)
)
return "\n".join(out)
@staticmethod
def wait_for_operation(
operation_request,
test_success_fn,
timeout_sec=_WAIT_FOR_OPERATION_SEC,
wait_sec=_WAIT_FIXED_SEC,
):
retryer = tenacity.Retrying(
retry=(
tenacity.retry_if_not_result(test_success_fn)
| tenacity.retry_if_exception_type()
),
wait=tenacity.wait_fixed(wait_sec),
stop=tenacity.stop_after_delay(timeout_sec),
after=tenacity.after_log(logger, logging.DEBUG),
reraise=True,
)
return retryer(operation_request.execute)
class GcpStandardCloudApiResource(GcpProjectApiResource, metaclass=abc.ABCMeta):
GLOBAL_LOCATION = "global"
def parent(self, location: Optional[str] = GLOBAL_LOCATION):
if location is None:
location = self.GLOBAL_LOCATION
return f"projects/{self.project}/locations/{location}"
def resource_full_name(self, name, collection_name):
return f"{self.parent()}/{collection_name}/{name}"
def _create_resource(
self, collection: discovery.Resource, body: dict, **kwargs
):
logger.info(
"Creating %s resource:\n%s",
self.api_name,
self.resource_pretty_format(body),
)
create_req = collection.create(
parent=self.parent(), body=body, **kwargs
)
self._execute(create_req)
@property
@abc.abstractmethod
def api_name(self) -> str:
raise NotImplementedError
@property
@abc.abstractmethod
def api_version(self) -> str:
raise NotImplementedError
def _get_resource(self, collection: discovery.Resource, full_name):
resource = collection.get(name=full_name).execute()
logger.info(
"Loaded %s:\n%s", full_name, self.resource_pretty_format(resource)
)
return resource
def _delete_resource(
self, collection: discovery.Resource, full_name: str
) -> bool:
logger.debug("Deleting %s", full_name)
try:
self._execute(collection.delete(name=full_name))
return True
except _HttpError as error:
if error.resp and error.resp.status == 404:
logger.debug("%s not deleted since it doesn't exist", full_name)
else:
logger.warning("Failed to delete %s, %r", full_name, error)
return False
# TODO(sergiitk): Use ResponseError and TransportError
def _execute( # pylint: disable=arguments-differ
self,
request: HttpRequest,
timeout_sec: int = GcpProjectApiResource._WAIT_FOR_OPERATION_SEC,
):
operation = request.execute(num_retries=self._GCP_API_RETRIES)
logger.debug("Operation %s", operation)
self._wait(operation["name"], timeout_sec)
def _wait(
self,
operation_id: str,
timeout_sec: int = GcpProjectApiResource._WAIT_FOR_OPERATION_SEC,
):
logger.info(
"Waiting %s sec for %s operation id: %s",
timeout_sec,
self.api_name,
operation_id,
)
op_request = (
self.api.projects().locations().operations().get(name=operation_id)
)
operation = self.wait_for_operation(
operation_request=op_request,
test_success_fn=lambda result: result["done"],
timeout_sec=timeout_sec,
)
logger.debug("Completed operation: %s", operation)
if "error" in operation:
raise OperationError(self.api_name, operation)

@@ -1,637 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import datetime
import enum
import logging
from typing import Any, Dict, List, Optional, Set
from googleapiclient import discovery
import googleapiclient.errors
import httplib2
import framework.errors
from framework.helpers import retryers
from framework.infrastructure import gcp
logger = logging.getLogger(__name__)
DEBUG_HEADER_IN_RESPONSE = "x-encrypted-debug-headers"
DEBUG_HEADER_KEY = "X-Return-Encrypted-Headers"
class ComputeV1(
gcp.api.GcpProjectApiResource
): # pylint: disable=too-many-public-methods
# TODO(sergiitk): move someplace better
_WAIT_FOR_BACKEND_SEC = 60 * 10
_WAIT_FOR_BACKEND_SLEEP_SEC = 4
_WAIT_FOR_OPERATION_SEC = 60 * 10
gfe_debug_header: Optional[str]
@dataclasses.dataclass(frozen=True)
class GcpResource:
name: str
url: str
@dataclasses.dataclass(frozen=True)
class ZonalGcpResource(GcpResource):
zone: str
def __init__(
self,
api_manager: gcp.api.GcpApiManager,
project: str,
gfe_debug_header: Optional[str] = None,
version: str = "v1",
):
super().__init__(api_manager.compute(version), project)
self.gfe_debug_header = gfe_debug_header
class HealthCheckProtocol(enum.Enum):
TCP = enum.auto()
GRPC = enum.auto()
class BackendServiceProtocol(enum.Enum):
HTTP2 = enum.auto()
GRPC = enum.auto()
def create_health_check(
self,
name: str,
protocol: HealthCheckProtocol,
*,
port: Optional[int] = None,
) -> "GcpResource":
if protocol is self.HealthCheckProtocol.TCP:
health_check_field = "tcpHealthCheck"
elif protocol is self.HealthCheckProtocol.GRPC:
health_check_field = "grpcHealthCheck"
else:
raise TypeError(f"Unexpected Health Check protocol: {protocol}")
health_check_settings = {}
if port is None:
health_check_settings["portSpecification"] = "USE_SERVING_PORT"
else:
health_check_settings["portSpecification"] = "USE_FIXED_PORT"
health_check_settings["port"] = port
return self._insert_resource(
self.api.healthChecks(),
{
"name": name,
"type": protocol.name,
health_check_field: health_check_settings,
},
)
def list_health_check(self):
return self._list_resource(self.api.healthChecks())
def delete_health_check(self, name):
self._delete_resource(self.api.healthChecks(), "healthCheck", name)
def create_firewall_rule(
self,
name: str,
network_url: str,
source_ranges: List[str],
ports: List[str],
) -> Optional["GcpResource"]:
try:
return self._insert_resource(
self.api.firewalls(),
{
"allowed": [{"IPProtocol": "tcp", "ports": ports}],
"direction": "INGRESS",
"name": name,
"network": network_url,
"priority": 1000,
"sourceRanges": source_ranges,
"targetTags": ["allow-health-checks"],
},
)
except googleapiclient.errors.HttpError as http_error:
# TODO(lidiz) use status_code() when we upgrade googleapiclient
if http_error.resp.status == 409:
logger.debug("Firewall rule %s already existed", name)
return None
else:
raise
def delete_firewall_rule(self, name):
self._delete_resource(self.api.firewalls(), "firewall", name)
def create_backend_service_traffic_director(
self,
name: str,
health_check: "GcpResource",
affinity_header: Optional[str] = None,
protocol: Optional[BackendServiceProtocol] = None,
subset_size: Optional[int] = None,
locality_lb_policies: Optional[List[dict]] = None,
outlier_detection: Optional[dict] = None,
) -> "GcpResource":
if not isinstance(protocol, self.BackendServiceProtocol):
raise TypeError(f"Unexpected Backend Service protocol: {protocol}")
body = {
"name": name,
"loadBalancingScheme": "INTERNAL_SELF_MANAGED", # Traffic Director
"healthChecks": [health_check.url],
"protocol": protocol.name,
}
# If affinity header is specified, config the backend service to support
# affinity, and set affinity header to the one given.
if affinity_header:
body["sessionAffinity"] = "HEADER_FIELD"
body["localityLbPolicy"] = "RING_HASH"
body["consistentHash"] = {
"httpHeaderName": affinity_header,
}
if subset_size:
body["subsetting"] = {
"policy": "CONSISTENT_HASH_SUBSETTING",
"subsetSize": subset_size,
}
if locality_lb_policies:
body["localityLbPolicies"] = locality_lb_policies
if outlier_detection:
body["outlierDetection"] = outlier_detection
return self._insert_resource(self.api.backendServices(), body)
def get_backend_service_traffic_director(self, name: str) -> "GcpResource":
return self._get_resource(
self.api.backendServices(), backendService=name
)
def patch_backend_service(self, backend_service, body, **kwargs):
self._patch_resource(
collection=self.api.backendServices(),
backendService=backend_service.name,
body=body,
**kwargs,
)
def backend_service_patch_backends(
self,
backend_service,
backends,
max_rate_per_endpoint: Optional[int] = None,
):
if max_rate_per_endpoint is None:
max_rate_per_endpoint = 5
backend_list = [
{
"group": backend.url,
"balancingMode": "RATE",
"maxRatePerEndpoint": max_rate_per_endpoint,
}
for backend in backends
]
self._patch_resource(
collection=self.api.backendServices(),
body={"backends": backend_list},
backendService=backend_service.name,
)
def backend_service_remove_all_backends(self, backend_service):
self._patch_resource(
collection=self.api.backendServices(),
body={"backends": []},
backendService=backend_service.name,
)
def delete_backend_service(self, name):
self._delete_resource(
self.api.backendServices(), "backendService", name
)
def create_url_map(
self,
name: str,
matcher_name: str,
src_hosts,
dst_default_backend_service: "GcpResource",
dst_host_rule_match_backend_service: Optional["GcpResource"] = None,
) -> "GcpResource":
if dst_host_rule_match_backend_service is None:
dst_host_rule_match_backend_service = dst_default_backend_service
return self._insert_resource(
self.api.urlMaps(),
{
"name": name,
"defaultService": dst_default_backend_service.url,
"hostRules": [
{
"hosts": src_hosts,
"pathMatcher": matcher_name,
}
],
"pathMatchers": [
{
"name": matcher_name,
"defaultService": dst_host_rule_match_backend_service.url,
}
],
},
)
def create_url_map_with_content(self, url_map_body: Any) -> "GcpResource":
return self._insert_resource(self.api.urlMaps(), url_map_body)
def patch_url_map(self, url_map: "GcpResource", body, **kwargs):
self._patch_resource(
collection=self.api.urlMaps(),
urlMap=url_map.name,
body=body,
**kwargs,
)
def delete_url_map(self, name):
self._delete_resource(self.api.urlMaps(), "urlMap", name)
def create_target_grpc_proxy(
self,
name: str,
url_map: "GcpResource",
validate_for_proxyless: bool = True,
) -> "GcpResource":
return self._insert_resource(
self.api.targetGrpcProxies(),
{
"name": name,
"url_map": url_map.url,
"validate_for_proxyless": validate_for_proxyless,
},
)
def delete_target_grpc_proxy(self, name):
self._delete_resource(
self.api.targetGrpcProxies(), "targetGrpcProxy", name
)
def create_target_http_proxy(
self,
name: str,
url_map: "GcpResource",
) -> "GcpResource":
return self._insert_resource(
self.api.targetHttpProxies(),
{
"name": name,
"url_map": url_map.url,
},
)
def delete_target_http_proxy(self, name):
self._delete_resource(
self.api.targetHttpProxies(), "targetHttpProxy", name
)
def create_forwarding_rule(
self,
name: str,
src_port: int,
target_proxy: "GcpResource",
network_url: str,
*,
ip_address: str = "0.0.0.0",
) -> "GcpResource":
return self._insert_resource(
self.api.globalForwardingRules(),
{
"name": name,
"loadBalancingScheme": "INTERNAL_SELF_MANAGED", # Traffic Director
"portRange": src_port,
"IPAddress": ip_address,
"network": network_url,
"target": target_proxy.url,
},
)
def exists_forwarding_rule(self, src_port) -> bool:
# TODO(sergiitk): Better approach for confirming the port is available.
        # It's possible a rule allocates an actual port range, e.g. 8000-9000,
        # and this wouldn't catch it. For now, we assume there are no
        # port ranges used in the project.
filter_str = (
f'(portRange eq "{src_port}-{src_port}") '
            '(IPAddress eq "0.0.0.0") '
'(loadBalancingScheme eq "INTERNAL_SELF_MANAGED")'
)
return self._exists_resource(
self.api.globalForwardingRules(), resource_filter=filter_str
)
def delete_forwarding_rule(self, name):
self._delete_resource(
self.api.globalForwardingRules(), "forwardingRule", name
)
def wait_for_network_endpoint_group(
self,
name: str,
zone: str,
*,
timeout_sec=_WAIT_FOR_BACKEND_SEC,
wait_sec=_WAIT_FOR_BACKEND_SLEEP_SEC,
):
retryer = retryers.constant_retryer(
wait_fixed=datetime.timedelta(seconds=wait_sec),
timeout=datetime.timedelta(seconds=timeout_sec),
check_result=lambda neg: neg and neg.get("size", 0) > 0,
)
network_endpoint_group = retryer(
self._retry_load_network_endpoint_group, name, zone
)
# TODO(sergiitk): dataclass
return self.ZonalGcpResource(
network_endpoint_group["name"],
network_endpoint_group["selfLink"],
zone,
)
def _retry_load_network_endpoint_group(self, name: str, zone: str):
try:
neg = self.get_network_endpoint_group(name, zone)
logger.debug(
"Waiting for endpoints: NEG %s in zone %s, current count %s",
neg["name"],
zone,
neg.get("size"),
)
except googleapiclient.errors.HttpError as error:
# noinspection PyProtectedMember
reason = error._get_reason()
logger.debug(
"Retrying NEG load, got %s, details %s",
error.resp.status,
reason,
)
raise
return neg
def get_network_endpoint_group(self, name, zone):
neg = (
self.api.networkEndpointGroups()
.get(project=self.project, networkEndpointGroup=name, zone=zone)
.execute()
)
# TODO(sergiitk): dataclass
return neg
def wait_for_backends_healthy_status(
self,
backend_service: GcpResource,
backends: Set[ZonalGcpResource],
*,
timeout_sec: int = _WAIT_FOR_BACKEND_SEC,
wait_sec: int = _WAIT_FOR_BACKEND_SLEEP_SEC,
) -> None:
if not backends:
raise ValueError("The list of backends to wait on is empty")
timeout = datetime.timedelta(seconds=timeout_sec)
retryer = retryers.constant_retryer(
wait_fixed=datetime.timedelta(seconds=wait_sec),
timeout=timeout,
check_result=lambda result: result,
)
pending = set(backends)
try:
retryer(self._retry_backends_health, backend_service, pending)
except retryers.RetryError as retry_err:
unhealthy_backends: str = ",".join(
[backend.name for backend in pending]
)
# Attempt to load backend health info for better debug info.
try:
unhealthy = []
# Everything left in pending was unhealthy on the last retry.
for backend in pending:
# It's possible the health status has changed since we
# gave up retrying, but this should be very rare.
health_status = self.get_backend_service_backend_health(
backend_service,
backend,
)
unhealthy.append(
{"name": backend.name, "health_status": health_status}
)
                # Override the plain list of unhealthy backend names with
# the one showing the latest backend statuses.
unhealthy_backends = self.resources_pretty_format(
unhealthy,
highlight=False,
)
except Exception as error: # noqa pylint: disable=broad-except
logger.debug(
"Couldn't load backend health info, plain list name"
"will be printed instead. Error: %r",
error,
)
retry_err.add_note(
framework.errors.FrameworkError.note_blanket_error_info_below(
"One or several NEGs (Network Endpoint Groups) didn't"
" report HEALTHY status within expected timeout.",
info_below=(
f"Timeout {timeout} (h:mm:ss) waiting for backend"
f" service '{backend_service.name}' to report all NEGs"
" in the HEALTHY status:"
f" {[backend.name for backend in backends]}."
f"\nUnhealthy backends:\n{unhealthy_backends}"
),
)
)
raise
def _retry_backends_health(
self, backend_service: GcpResource, pending: Set[ZonalGcpResource]
):
for backend in pending:
result = self.get_backend_service_backend_health(
backend_service, backend
)
if "healthStatus" not in result:
logger.debug(
"Waiting for instances: backend %s, zone %s",
backend.name,
backend.zone,
)
continue
backend_healthy = True
for instance in result["healthStatus"]:
logger.debug(
"Backend %s in zone %s: instance %s:%s health: %s",
backend.name,
backend.zone,
instance["ipAddress"],
instance["port"],
instance["healthState"],
)
if instance["healthState"] != "HEALTHY":
backend_healthy = False
if backend_healthy:
logger.info(
"Backend %s in zone %s reported healthy",
backend.name,
backend.zone,
)
pending.remove(backend)
return not pending
def get_backend_service_backend_health(self, backend_service, backend):
return (
self.api.backendServices()
.getHealth(
project=self.project,
backendService=backend_service.name,
body={"group": backend.url},
)
.execute()
)
def _get_resource(
self, collection: discovery.Resource, **kwargs
) -> "GcpResource":
resp = collection.get(project=self.project, **kwargs).execute()
logger.info(
"Loaded compute resource:\n%s", self.resource_pretty_format(resp)
)
return self.GcpResource(resp["name"], resp["selfLink"])
def _exists_resource(
self, collection: discovery.Resource, resource_filter: str
) -> bool:
resp = collection.list(
project=self.project, filter=resource_filter, maxResults=1
).execute(num_retries=self._GCP_API_RETRIES)
if "kind" not in resp:
# TODO(sergiitk): better error
raise ValueError('List response "kind" is missing')
return "items" in resp and resp["items"]
def _insert_resource(
self, collection: discovery.Resource, body: Dict[str, Any]
) -> "GcpResource":
logger.info(
"Creating compute resource:\n%s", self.resource_pretty_format(body)
)
resp = self._execute(collection.insert(project=self.project, body=body))
return self.GcpResource(body["name"], resp["targetLink"])
def _patch_resource(self, collection, body, **kwargs):
logger.info(
"Patching compute resource:\n%s", self.resource_pretty_format(body)
)
self._execute(
collection.patch(project=self.project, body=body, **kwargs)
)
def _list_resource(self, collection: discovery.Resource):
return collection.list(project=self.project).execute(
num_retries=self._GCP_API_RETRIES
)
def _delete_resource(
self,
collection: discovery.Resource,
resource_type: str,
resource_name: str,
) -> bool:
try:
params = {"project": self.project, resource_type: resource_name}
self._execute(collection.delete(**params))
return True
except googleapiclient.errors.HttpError as error:
if error.resp and error.resp.status == 404:
logger.debug(
"Resource %s %s not deleted since it doesn't exist",
resource_type,
resource_name,
)
else:
logger.warning(
'Failed to delete %s "%s", %r',
resource_type,
resource_name,
error,
)
return False
@staticmethod
def _operation_status_done(operation):
return "status" in operation and operation["status"] == "DONE"
@staticmethod
def _log_debug_header(resp: httplib2.Response):
if (
DEBUG_HEADER_IN_RESPONSE in resp
and resp.status >= 300
and resp.status != 404
):
logger.info(
"Received GCP debug headers: %s",
resp[DEBUG_HEADER_IN_RESPONSE],
)
def _execute( # pylint: disable=arguments-differ
self, request, *, timeout_sec=_WAIT_FOR_OPERATION_SEC
):
if self.gfe_debug_header:
logger.debug(
"Adding debug headers for method: %s", request.methodId
)
request.headers[DEBUG_HEADER_KEY] = self.gfe_debug_header
request.add_response_callback(self._log_debug_header)
operation = request.execute(num_retries=self._GCP_API_RETRIES)
logger.debug("Operation %s", operation)
return self._wait(operation["name"], timeout_sec)
def _wait(
self, operation_id: str, timeout_sec: int = _WAIT_FOR_OPERATION_SEC
) -> dict:
logger.info(
"Waiting %s sec for compute operation id: %s",
timeout_sec,
operation_id,
)
# TODO(sergiitk) try using wait() here
# https://googleapis.github.io/google-api-python-client/docs/dyn/compute_v1.globalOperations.html#wait
op_request = self.api.globalOperations().get(
project=self.project, operation=operation_id
)
operation = self.wait_for_operation(
operation_request=op_request,
test_success_fn=self._operation_status_done,
timeout_sec=timeout_sec,
)
logger.debug("Completed operation: %s", operation)
if "error" in operation:
# This shouldn't normally happen: gcp library raises on errors.
raise Exception(
f"Compute operation {operation_id} failed: {operation}"
)
return operation

@@ -1,361 +0,0 @@
# Copyright 2021 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import datetime
import functools
import logging
from typing import Any, Dict, FrozenSet, Optional
from framework.helpers import retryers
from framework.infrastructure import gcp
logger = logging.getLogger(__name__)
# Type aliases
_timedelta = datetime.timedelta
_HttpRequest = gcp.api.HttpRequest
class EtagConflict(gcp.api.Error):
"""
Indicates concurrent policy changes.
https://cloud.google.com/iam/docs/policies#etag
"""
def handle_etag_conflict(func):
def wrap_retry_on_etag_conflict(*args, **kwargs):
retryer = retryers.exponential_retryer_with_timeout(
retry_on_exceptions=(EtagConflict, gcp.api.TransportError),
wait_min=_timedelta(seconds=1),
wait_max=_timedelta(seconds=10),
timeout=_timedelta(minutes=2),
)
return retryer(func, *args, **kwargs)
return wrap_retry_on_etag_conflict
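# --- Editor's usage sketch (not part of the original file). Any
# read-modify-write of the same policy can be wrapped the same way the
# methods below are, f.e. a hypothetical removal helper:
#
#   @handle_etag_conflict
#   def remove_member(iam: "IamV1", account: str, role: str, member: str):
#       policy = iam.get_service_account_iam_policy(account)
#       ...  # Drop the member, then iam.set_service_account_iam_policy().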
def _replace_binding(
policy: "Policy", binding: "Policy.Binding", new_binding: "Policy.Binding"
) -> "Policy":
new_bindings = set(policy.bindings)
new_bindings.discard(binding)
new_bindings.add(new_binding)
# pylint: disable=too-many-function-args # No idea why pylint is like that.
return dataclasses.replace(policy, bindings=frozenset(new_bindings))
@dataclasses.dataclass(frozen=True)
class ServiceAccount:
"""An IAM service account.
https://cloud.google.com/iam/docs/reference/rest/v1/projects.serviceAccounts
Note: "etag" field is skipped because it's deprecated
"""
name: str
projectId: str
uniqueId: str
email: str
oauth2ClientId: str
displayName: str = ""
description: str = ""
disabled: bool = False
@classmethod
def from_response(cls, response: Dict[str, Any]) -> "ServiceAccount":
return cls(
name=response["name"],
projectId=response["projectId"],
uniqueId=response["uniqueId"],
email=response["email"],
oauth2ClientId=response["oauth2ClientId"],
description=response.get("description", ""),
displayName=response.get("displayName", ""),
disabled=response.get("disabled", False),
)
def as_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
@dataclasses.dataclass(frozen=True)
class Expr:
"""
Represents a textual expression in the Common Expression Language syntax.
https://cloud.google.com/iam/docs/reference/rest/v1/Expr
"""
expression: str
title: str = ""
description: str = ""
location: str = ""
@classmethod
def from_response(cls, response: Dict[str, Any]) -> "Expr":
return cls(**response)
def as_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
@dataclasses.dataclass(frozen=True)
class Policy:
"""An Identity and Access Management (IAM) policy, which specifies
access controls for Google Cloud resources.
https://cloud.google.com/iam/docs/reference/rest/v1/Policy
Note: auditConfigs not supported by this implementation.
"""
@dataclasses.dataclass(frozen=True)
class Binding:
"""Policy Binding. Associates members with a role.
https://cloud.google.com/iam/docs/reference/rest/v1/Policy#binding
"""
role: str
members: FrozenSet[str]
condition: Optional[Expr] = None
@classmethod
def from_response(cls, response: Dict[str, Any]) -> "Policy.Binding":
fields = {
"role": response["role"],
"members": frozenset(response.get("members", [])),
}
if "condition" in response:
fields["condition"] = Expr.from_response(response["condition"])
return cls(**fields)
def as_dict(self) -> Dict[str, Any]:
result = {
"role": self.role,
"members": list(self.members),
}
if self.condition is not None:
result["condition"] = self.condition.as_dict()
return result
bindings: FrozenSet[Binding]
etag: str
version: Optional[int] = None
@functools.lru_cache(maxsize=128)
def find_binding_for_role(
self, role: str, condition: Optional[Expr] = None
) -> Optional["Policy.Binding"]:
results = (
binding
for binding in self.bindings
if binding.role == role and binding.condition == condition
)
return next(results, None)
@classmethod
def from_response(cls, response: Dict[str, Any]) -> "Policy":
bindings = frozenset(
cls.Binding.from_response(b) for b in response.get("bindings", [])
)
return cls(
bindings=bindings,
etag=response["etag"],
version=response.get("version"),
)
def as_dict(self) -> Dict[str, Any]:
result = {
"bindings": [binding.as_dict() for binding in self.bindings],
"etag": self.etag,
}
if self.version is not None:
result["version"] = self.version
return result
class IamV1(gcp.api.GcpProjectApiResource):
"""
Identity and Access Management (IAM) API.
https://cloud.google.com/iam/docs/reference/rest
"""
_service_accounts: gcp.api.discovery.Resource
# Operations that affect conditional role bindings must specify version 3.
# Otherwise conditions are omitted, and role names are returned with a suffix,
# e.g. roles/iam.workloadIdentityUser_withcond_f1ec33c9beb41857dbf0
# https://cloud.google.com/iam/docs/reference/rest/v1/Policy#FIELDS.version
POLICY_VERSION: int = 3
def __init__(self, api_manager: gcp.api.GcpApiManager, project: str):
super().__init__(api_manager.iam("v1"), project)
# Shortcut to projects/*/serviceAccounts/ endpoints
self._service_accounts = self.api.projects().serviceAccounts()
def service_account_resource_name(self, account) -> str:
"""
Returns full resource name of the service account.
The resource name of the service account in the following format:
projects/{PROJECT_ID}/serviceAccounts/{ACCOUNT}.
The ACCOUNT value can be the email address or the uniqueId of the
service account.
Ref https://cloud.google.com/iam/docs/reference/rest/v1/projects.serviceAccounts/get
Args:
account: The ACCOUNT value
"""
return f"projects/{self.project}/serviceAccounts/{account}"
def get_service_account(self, account: str) -> ServiceAccount:
resource_name = self.service_account_resource_name(account)
request: _HttpRequest = self._service_accounts.get(name=resource_name)
response: Dict[str, Any] = self._execute(request)
logger.debug(
"Loaded Service Account:\n%s", self.resource_pretty_format(response)
)
return ServiceAccount.from_response(response)
def get_service_account_iam_policy(self, account: str) -> Policy:
resource_name = self.service_account_resource_name(account)
request: _HttpRequest = self._service_accounts.getIamPolicy(
resource=resource_name,
options_requestedPolicyVersion=self.POLICY_VERSION,
)
response: Dict[str, Any] = self._execute(request)
logger.debug(
"Loaded Service Account Policy:\n%s",
self.resource_pretty_format(response),
)
return Policy.from_response(response)
def set_service_account_iam_policy(
self, account: str, policy: Policy
) -> Policy:
"""Sets the IAM policy that is attached to a service account.
https://cloud.google.com/iam/docs/reference/rest/v1/projects.serviceAccounts/setIamPolicy
"""
resource_name = self.service_account_resource_name(account)
body = {"policy": policy.as_dict()}
logger.debug(
"Updating Service Account %s policy:\n%s",
account,
self.resource_pretty_format(body),
)
try:
request: _HttpRequest = self._service_accounts.setIamPolicy(
resource=resource_name, body=body
)
response: Dict[str, Any] = self._execute(request)
return Policy.from_response(response)
except gcp.api.ResponseError as error:
if error.status == 409:
# https://cloud.google.com/iam/docs/policies#etag
logger.debug(error)
raise EtagConflict from error
raise
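# Note: handle_etag_conflict (not shown in this hunk) is expected to re-run
# the whole read-modify-write cycle when setIamPolicy raises EtagConflict,
# so a concurrent policy edit results in a retry rather than a lost update.
# Ref: https://cloud.google.com/iam/docs/policies#etag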
@handle_etag_conflict
def add_service_account_iam_policy_binding(
self, account: str, role: str, member: str
) -> None:
"""Add an IAM policy binding to an IAM service account.
For details on updating policy bindings, see:
https://cloud.google.com/iam/docs/reference/rest/v1/projects.serviceAccounts/setIamPolicy
"""
policy: Policy = self.get_service_account_iam_policy(account)
binding: Optional[Policy.Binding] = policy.find_binding_for_role(role)
if binding and member in binding.members:
logger.debug(
"Member %s already has role %s for Service Account %s",
member,
role,
account,
)
return
if binding is None:
updated_binding = Policy.Binding(role, frozenset([member]))
else:
updated_members: FrozenSet[str] = binding.members.union({member})
updated_binding: Policy.Binding = (
dataclasses.replace( # pylint: disable=too-many-function-args
binding, members=updated_members
)
)
updated_policy: Policy = _replace_binding(
policy, binding, updated_binding
)
self.set_service_account_iam_policy(account, updated_policy)
logger.debug(
"Role %s granted to member %s for Service Account %s",
role,
member,
account,
)
@handle_etag_conflict
def remove_service_account_iam_policy_binding(
self, account: str, role: str, member: str
) -> None:
"""Remove an IAM policy binding from the IAM policy of a service
account.
For details on updating policy bindings, see:
https://cloud.google.com/iam/docs/reference/rest/v1/projects.serviceAccounts/setIamPolicy
"""
policy: Policy = self.get_service_account_iam_policy(account)
binding: Optional[Policy.Binding] = policy.find_binding_for_role(role)
if binding is None:
logger.debug(
"Noop: Service Account %s has no bindings for role %s",
account,
role,
)
return
if member not in binding.members:
logger.debug(
"Noop: Service Account %s binding for role %s has no member %s",
account,
role,
member,
)
return
updated_members: FrozenSet[str] = binding.members.difference({member})
updated_binding: Policy.Binding = (
dataclasses.replace( # pylint: disable=too-many-function-args
binding, members=updated_members
)
)
updated_policy: Policy = _replace_binding(
policy, binding, updated_binding
)
self.set_service_account_iam_policy(account, updated_policy)
logger.debug(
"Role %s revoked from member %s for Service Account %s",
role,
member,
account,
)
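# Illustrative usage sketch (project and account values are hypothetical):
#
#     iam = IamV1(api_manager, "my-project")
#     iam.add_service_account_iam_policy_binding(
#         "my-sa@my-project.iam.gserviceaccount.com",
#         "roles/iam.workloadIdentityUser",
#         "serviceAccount:my-project.svc.id.goog[my-namespace/my-ksa]",
#     )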

@@ -1,221 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import dataclasses
import logging
from typing import Any, Dict
from google.rpc import code_pb2
import tenacity
from framework.infrastructure import gcp
logger = logging.getLogger(__name__)
# Type aliases
GcpResource = gcp.compute.ComputeV1.GcpResource
@dataclasses.dataclass(frozen=True)
class ServerTlsPolicy:
url: str
name: str
server_certificate: dict
mtls_policy: dict
update_time: str
create_time: str
@classmethod
def from_response(
cls, name: str, response: Dict[str, Any]
) -> "ServerTlsPolicy":
return cls(
name=name,
url=response["name"],
server_certificate=response.get("serverCertificate", {}),
mtls_policy=response.get("mtlsPolicy", {}),
create_time=response["createTime"],
update_time=response["updateTime"],
)
@dataclasses.dataclass(frozen=True)
class ClientTlsPolicy:
url: str
name: str
client_certificate: dict
server_validation_ca: list
update_time: str
create_time: str
@classmethod
def from_response(
cls, name: str, response: Dict[str, Any]
) -> "ClientTlsPolicy":
return cls(
name=name,
url=response["name"],
client_certificate=response.get("clientCertificate", {}),
server_validation_ca=response.get("serverValidationCa", []),
create_time=response["createTime"],
update_time=response["updateTime"],
)
@dataclasses.dataclass(frozen=True)
class AuthorizationPolicy:
url: str
name: str
update_time: str
create_time: str
action: str
rules: list
@classmethod
def from_response(
cls, name: str, response: Dict[str, Any]
) -> "AuthorizationPolicy":
return cls(
name=name,
url=response["name"],
create_time=response["createTime"],
update_time=response["updateTime"],
action=response["action"],
rules=response.get("rules", []),
)
class _NetworkSecurityBase(
gcp.api.GcpStandardCloudApiResource, metaclass=abc.ABCMeta
):
"""Base class for NetworkSecurity APIs."""
# TODO(https://github.com/grpc/grpc/issues/29532) remove pylint disable
# pylint: disable=abstract-method
def __init__(self, api_manager: gcp.api.GcpApiManager, project: str):
super().__init__(api_manager.networksecurity(self.api_version), project)
# Shortcut to projects/*/locations/ endpoints
self._api_locations = self.api.projects().locations()
@property
def api_name(self) -> str:
return "networksecurity"
def _execute(
self, *args, **kwargs
): # pylint: disable=signature-differs,arguments-differ
# Workaround TD bug: throttled operations are reported as internal.
# Ref b/175345578
retryer = tenacity.Retrying(
retry=tenacity.retry_if_exception(self._operation_internal_error),
wait=tenacity.wait_fixed(10),
stop=tenacity.stop_after_delay(5 * 60),
before_sleep=tenacity.before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
retryer(super()._execute, *args, **kwargs)
@staticmethod
def _operation_internal_error(exception):
return (
isinstance(exception, gcp.api.OperationError)
and exception.error.code == code_pb2.INTERNAL
)
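# With the settings above, the retryer re-issues a throttled operation every
# 10 seconds for up to 5 minutes (roughly 30 attempts) and re-raises the last
# OperationError if the INTERNAL status persists.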
class NetworkSecurityV1Beta1(_NetworkSecurityBase):
"""NetworkSecurity API v1beta1."""
SERVER_TLS_POLICIES = "serverTlsPolicies"
CLIENT_TLS_POLICIES = "clientTlsPolicies"
AUTHZ_POLICIES = "authorizationPolicies"
@property
def api_version(self) -> str:
return "v1beta1"
def create_server_tls_policy(self, name: str, body: dict) -> GcpResource:
return self._create_resource(
collection=self._api_locations.serverTlsPolicies(),
body=body,
serverTlsPolicyId=name,
)
def get_server_tls_policy(self, name: str) -> ServerTlsPolicy:
response = self._get_resource(
collection=self._api_locations.serverTlsPolicies(),
full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES),
)
return ServerTlsPolicy.from_response(name, response)
def delete_server_tls_policy(self, name: str) -> bool:
return self._delete_resource(
collection=self._api_locations.serverTlsPolicies(),
full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES),
)
def create_client_tls_policy(self, name: str, body: dict) -> GcpResource:
return self._create_resource(
collection=self._api_locations.clientTlsPolicies(),
body=body,
clientTlsPolicyId=name,
)
def get_client_tls_policy(self, name: str) -> ClientTlsPolicy:
response = self._get_resource(
collection=self._api_locations.clientTlsPolicies(),
full_name=self.resource_full_name(name, self.CLIENT_TLS_POLICIES),
)
return ClientTlsPolicy.from_response(name, response)
def delete_client_tls_policy(self, name: str) -> bool:
return self._delete_resource(
collection=self._api_locations.clientTlsPolicies(),
full_name=self.resource_full_name(name, self.CLIENT_TLS_POLICIES),
)
def create_authz_policy(self, name: str, body: dict) -> GcpResource:
return self._create_resource(
collection=self._api_locations.authorizationPolicies(),
body=body,
authorizationPolicyId=name,
)
def get_authz_policy(self, name: str) -> AuthorizationPolicy:
response = self._get_resource(
collection=self._api_locations.authorizationPolicies(),
full_name=self.resource_full_name(name, self.AUTHZ_POLICIES),
)
return AuthorizationPolicy.from_response(name, response)
def delete_authz_policy(self, name: str) -> bool:
return self._delete_resource(
collection=self._api_locations.authorizationPolicies(),
full_name=self.resource_full_name(name, self.AUTHZ_POLICIES),
)
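# Illustrative usage sketch (the policy body follows the networksecurity API;
# the certificate provider value is hypothetical):
#
#     netsec = NetworkSecurityV1Beta1(api_manager, "my-project")
#     netsec.create_server_tls_policy("my-server-tls-policy", {
#         "serverCertificate": {
#             "certificateProviderInstance": {
#                 "pluginInstance": "google_cloud_private_spiffe",
#             },
#         },
#     })
#     policy = netsec.get_server_tls_policy("my-server-tls-policy")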
class NetworkSecurityV1Alpha1(NetworkSecurityV1Beta1):
"""NetworkSecurity API v1alpha1.
Note: extending the v1beta1 class presumes that v1beta1 is just the v1alpha1
API graduated into a more stable version. This is true in most cases. However,
the v1alpha1 class can always override and reimplement incompatible methods.
"""
@property
def api_version(self) -> str:
return "v1alpha1"

@@ -1,461 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import dataclasses
import logging
from typing import Any, Dict, List, Optional, Tuple
from google.rpc import code_pb2
import tenacity
from framework.infrastructure import gcp
logger = logging.getLogger(__name__)
# Type aliases
GcpResource = gcp.compute.ComputeV1.GcpResource
@dataclasses.dataclass(frozen=True)
class EndpointPolicy:
url: str
name: str
type: str
traffic_port_selector: dict
endpoint_matcher: dict
update_time: str
create_time: str
http_filters: Optional[dict] = None
server_tls_policy: Optional[str] = None
@classmethod
def from_response(
cls, name: str, response: Dict[str, Any]
) -> "EndpointPolicy":
return cls(
name=name,
url=response["name"],
type=response["type"],
server_tls_policy=response.get("serverTlsPolicy", None),
traffic_port_selector=response["trafficPortSelector"],
endpoint_matcher=response["endpointMatcher"],
http_filters=response.get("httpFilters", None),
update_time=response["updateTime"],
create_time=response["createTime"],
)
@dataclasses.dataclass(frozen=True)
class Mesh:
name: str
url: str
routes: Optional[List[str]]
@classmethod
def from_response(cls, name: str, d: Dict[str, Any]) -> "Mesh":
return cls(
name=name,
url=d["name"],
routes=list(d["routes"]) if "routes" in d else None,
)
@dataclasses.dataclass(frozen=True)
class GrpcRoute:
@dataclasses.dataclass(frozen=True)
class MethodMatch:
type: Optional[str]
grpc_service: Optional[str]
grpc_method: Optional[str]
case_sensitive: Optional[bool]
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.MethodMatch":
return cls(
type=d.get("type"),
grpc_service=d.get("grpcService"),
grpc_method=d.get("grpcMethod"),
case_sensitive=d.get("caseSensitive"),
)
@dataclasses.dataclass(frozen=True)
class HeaderMatch:
type: Optional[str]
key: str
value: str
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.HeaderMatch":
return cls(
type=d.get("type"),
key=d["key"],
value=d["value"],
)
@dataclasses.dataclass(frozen=True)
class RouteMatch:
method: Optional["GrpcRoute.MethodMatch"]
headers: Tuple["GrpcRoute.HeaderMatch"]
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.RouteMatch":
return cls(
method=GrpcRoute.MethodMatch.from_response(d["method"])
if "method" in d
else None,
headers=tuple(
GrpcRoute.HeaderMatch.from_response(h) for h in d["headers"]
)
if "headers" in d
else (),
)
@dataclasses.dataclass(frozen=True)
class Destination:
service_name: str
weight: Optional[int]
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.Destination":
return cls(
service_name=d["serviceName"],
weight=d.get("weight"),
)
@dataclasses.dataclass(frozen=True)
class RouteAction:
destinations: List["GrpcRoute.Destination"]
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.RouteAction":
destinations = (
[
GrpcRoute.Destination.from_response(dest)
for dest in d["destinations"]
]
if "destinations" in d
else []
)
return cls(destinations=destinations)
@dataclasses.dataclass(frozen=True)
class RouteRule:
matches: List["GrpcRoute.RouteMatch"]
action: "GrpcRoute.RouteAction"
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.RouteRule":
matches = (
[GrpcRoute.RouteMatch.from_response(m) for m in d["matches"]]
if "matches" in d
else []
)
return cls(
matches=matches,
action=GrpcRoute.RouteAction.from_response(d["action"]),
)
name: str
url: str
hostnames: Tuple[str, ...]
rules: Tuple["GrpcRoute.RouteRule", ...]
meshes: Optional[Tuple[str, ...]]
@classmethod
def from_response(
cls, name: str, d: Dict[str, Any]
) -> "GrpcRoute.RouteRule":
return cls(
name=name,
url=d["name"],
hostnames=tuple(d["hostnames"]),
rules=tuple(d["rules"]),
meshes=None if d.get("meshes") is None else tuple(d["meshes"]),
)
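# Illustrative only: a minimal grpcRoutes GET response shape this parser
# accepts (field values are hypothetical). Note that rules are kept as raw
# dicts by GrpcRoute.from_response; RouteRule.from_response is available for
# parsing them individually.
#
#     route = GrpcRoute.from_response("my-route", {
#         "name": "projects/p/locations/global/grpcRoutes/my-route",
#         "hostnames": ["xds:///helloworld-gke:8000"],
#         "rules": [{"action": {}}],
#     })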
@dataclasses.dataclass(frozen=True)
class HttpRoute:
@dataclasses.dataclass(frozen=True)
class MethodMatch:
type: Optional[str]
http_service: Optional[str]
http_method: Optional[str]
case_sensitive: Optional[bool]
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "HttpRoute.MethodMatch":
return cls(
type=d.get("type"),
http_service=d.get("httpService"),
http_method=d.get("httpMethod"),
case_sensitive=d.get("caseSensitive"),
)
@dataclasses.dataclass(frozen=True)
class HeaderMatch:
type: Optional[str]
key: str
value: str
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "HttpRoute.HeaderMatch":
return cls(
type=d.get("type"),
key=d["key"],
value=d["value"],
)
@dataclasses.dataclass(frozen=True)
class RouteMatch:
method: Optional["HttpRoute.MethodMatch"]
headers: Tuple["HttpRoute.HeaderMatch"]
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "HttpRoute.RouteMatch":
return cls(
method=HttpRoute.MethodMatch.from_response(d["method"])
if "method" in d
else None,
headers=tuple(
HttpRoute.HeaderMatch.from_response(h) for h in d["headers"]
)
if "headers" in d
else (),
)
@dataclasses.dataclass(frozen=True)
class Destination:
service_name: str
weight: Optional[int]
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "HttpRoute.Destination":
return cls(
service_name=d["serviceName"],
weight=d.get("weight"),
)
@dataclasses.dataclass(frozen=True)
class RouteAction:
destinations: List["HttpRoute.Destination"]
stateful_session_affinity: Optional["HttpRoute.StatefulSessionAffinity"]
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "HttpRoute.RouteAction":
destinations = (
[
HttpRoute.Destination.from_response(dest)
for dest in d["destinations"]
]
if "destinations" in d
else []
)
stateful_session_affinity = (
HttpRoute.StatefulSessionAffinity.from_response(
d["statefulSessionAffinity"]
)
if "statefulSessionAffinity" in d
else None
)
return cls(
destinations=destinations,
stateful_session_affinity=stateful_session_affinity,
)
@dataclasses.dataclass(frozen=True)
class StatefulSessionAffinity:
cookie_ttl: Optional[str]
@classmethod
def from_response(
cls, d: Dict[str, Any]
) -> "HttpRoute.StatefulSessionAffinity":
return cls(cookie_ttl=d.get("cookieTtl"))
@dataclasses.dataclass(frozen=True)
class RouteRule:
matches: List["HttpRoute.RouteMatch"]
action: "HttpRoute.RouteAction"
@classmethod
def from_response(cls, d: Dict[str, Any]) -> "HttpRoute.RouteRule":
matches = (
[HttpRoute.RouteMatch.from_response(m) for m in d["matches"]]
if "matches" in d
else []
)
return cls(
matches=matches,
action=HttpRoute.RouteAction.from_response(d["action"]),
)
name: str
url: str
hostnames: Tuple[str, ...]
rules: Tuple["HttpRoute.RouteRule", ...]
meshes: Optional[Tuple[str, ...]]
@classmethod
def from_response(cls, name: str, d: Dict[str, Any]) -> "HttpRoute":
return cls(
name=name,
url=d["name"],
hostnames=tuple(d["hostnames"]),
rules=tuple(d["rules"]),
meshes=None if d.get("meshes") is None else tuple(d["meshes"]),
)
class _NetworkServicesBase(
gcp.api.GcpStandardCloudApiResource, metaclass=abc.ABCMeta
):
"""Base class for NetworkServices APIs."""
# TODO(https://github.com/grpc/grpc/issues/29532) remove pylint disable
# pylint: disable=abstract-method
def __init__(self, api_manager: gcp.api.GcpApiManager, project: str):
super().__init__(api_manager.networkservices(self.api_version), project)
# Shortcut to projects/*/locations/ endpoints
self._api_locations = self.api.projects().locations()
@property
def api_name(self) -> str:
return "networkservices"
def _execute(
self, *args, **kwargs
): # pylint: disable=signature-differs,arguments-differ
# Workaround TD bug: throttled operations are reported as internal.
# Ref b/175345578
retryer = tenacity.Retrying(
retry=tenacity.retry_if_exception(self._operation_internal_error),
wait=tenacity.wait_fixed(10),
stop=tenacity.stop_after_delay(5 * 60),
before_sleep=tenacity.before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
retryer(super()._execute, *args, **kwargs)
@staticmethod
def _operation_internal_error(exception):
return (
isinstance(exception, gcp.api.OperationError)
and exception.error.code == code_pb2.INTERNAL
)
class NetworkServicesV1Beta1(_NetworkServicesBase):
"""NetworkServices API v1beta1."""
ENDPOINT_POLICIES = "endpointPolicies"
@property
def api_version(self) -> str:
return "v1beta1"
def create_endpoint_policy(self, name, body: dict) -> GcpResource:
return self._create_resource(
collection=self._api_locations.endpointPolicies(),
body=body,
endpointPolicyId=name,
)
def get_endpoint_policy(self, name: str) -> EndpointPolicy:
response = self._get_resource(
collection=self._api_locations.endpointPolicies(),
full_name=self.resource_full_name(name, self.ENDPOINT_POLICIES),
)
return EndpointPolicy.from_response(name, response)
def delete_endpoint_policy(self, name: str) -> bool:
return self._delete_resource(
collection=self._api_locations.endpointPolicies(),
full_name=self.resource_full_name(name, self.ENDPOINT_POLICIES),
)
class NetworkServicesV1Alpha1(NetworkServicesV1Beta1):
"""NetworkServices API v1alpha1.
Note: extending the v1beta1 class presumes that v1beta1 is just the v1alpha1
API graduated into a more stable version. This is true in most cases. However,
the v1alpha1 class can always override and reimplement incompatible methods.
"""
HTTP_ROUTES = "httpRoutes"
GRPC_ROUTES = "grpcRoutes"
MESHES = "meshes"
@property
def api_version(self) -> str:
return "v1alpha1"
def create_mesh(self, name: str, body: dict) -> GcpResource:
return self._create_resource(
collection=self._api_locations.meshes(), body=body, meshId=name
)
def get_mesh(self, name: str) -> Mesh:
full_name = self.resource_full_name(name, self.MESHES)
result = self._get_resource(
collection=self._api_locations.meshes(), full_name=full_name
)
return Mesh.from_response(name, result)
def delete_mesh(self, name: str) -> bool:
return self._delete_resource(
collection=self._api_locations.meshes(),
full_name=self.resource_full_name(name, self.MESHES),
)
def create_grpc_route(self, name: str, body: dict) -> GcpResource:
return self._create_resource(
collection=self._api_locations.grpcRoutes(),
body=body,
grpcRouteId=name,
)
def create_http_route(self, name: str, body: dict) -> GcpResource:
return self._create_resource(
collection=self._api_locations.httpRoutes(),
body=body,
httpRouteId=name,
)
def get_grpc_route(self, name: str) -> GrpcRoute:
full_name = self.resource_full_name(name, self.GRPC_ROUTES)
result = self._get_resource(
collection=self._api_locations.grpcRoutes(), full_name=full_name
)
return GrpcRoute.from_response(name, result)
def get_http_route(self, name: str) -> HttpRoute:
full_name = self.resource_full_name(name, self.HTTP_ROUTES)
result = self._get_resource(
collection=self._api_locations.httpRoutes(), full_name=full_name
)
return HttpRoute.from_response(name, result)
def delete_grpc_route(self, name: str) -> bool:
return self._delete_resource(
collection=self._api_locations.grpcRoutes(),
full_name=self.resource_full_name(name, self.GRPC_ROUTES),
)
def delete_http_route(self, name: str) -> bool:
return self._delete_resource(
collection=self._api_locations.httpRoutes(),
full_name=self.resource_full_name(name, self.HTTP_ROUTES),
)

@@ -1,13 +0,0 @@
# Copyright 2022 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -1,142 +0,0 @@
# Copyright 2022 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import pathlib
import threading
from typing import Any, Callable, Optional, TextIO
from kubernetes import client
from kubernetes.watch import watch
logger = logging.getLogger(__name__)
class PodLogCollector(threading.Thread):
"""A thread that streams logs from the remote pod to a local file."""
pod_name: str
namespace_name: str
stop_event: threading.Event
drain_event: threading.Event
log_path: pathlib.Path
log_to_stdout: bool
log_timestamps: bool
error_backoff_sec: int
_out_stream: Optional[TextIO]
_watcher: Optional[watch.Watch]
_read_pod_log_fn: Callable[..., Any]
def __init__(
self,
*,
pod_name: str,
namespace_name: str,
read_pod_log_fn: Callable[..., Any],
stop_event: threading.Event,
log_path: pathlib.Path,
log_to_stdout: bool = False,
log_timestamps: bool = False,
error_backoff_sec: int = 5,
):
self.pod_name = pod_name
self.namespace_name = namespace_name
self.stop_event = stop_event
# Set when log draining has finished. Less useful than hoped when logging
# is infrequent: the blocking happens in native code, which doesn't yield
# until the next log message arrives.
self.drain_event = threading.Event()
self.log_path = log_path
self.log_to_stdout = log_to_stdout
self.log_timestamps = log_timestamps
self.error_backoff_sec = error_backoff_sec
self._read_pod_log_fn = read_pod_log_fn
self._out_stream = None
self._watcher = None
super().__init__(name=f"pod-log-{pod_name}", daemon=True)
def run(self):
logger.info(
"Starting log collection thread %i for %s",
self.ident,
self.pod_name,
)
try:
self._out_stream = open(
self.log_path, "w", errors="ignore", encoding="utf-8"
)
while not self.stop_event.is_set():
self._stream_log()
finally:
self._stop()
def flush(self):
"""Flushes the log file buffer. May be called from the main thread."""
if self._out_stream:
self._out_stream.flush()
os.fsync(self._out_stream.fileno())
def _stop(self):
if self._watcher is not None:
self._watcher.stop()
self._watcher = None
if self._out_stream is not None:
self._write(
f"Finished log collection for pod {self.pod_name}",
force_flush=True,
)
self._out_stream.close()
self._out_stream = None
self.drain_event.set()
def _stream_log(self):
try:
self._restart_stream()
except client.ApiException as e:
self._write(f"Exception fetching logs: {e}")
self._write(
(
f"Restarting log fetching in {self.error_backoff_sec} sec. "
"Will attempt to read from the beginning, but log "
"truncation may occur."
),
force_flush=True,
)
finally:
# Instead of time.sleep(), we're waiting on the stop event
# in case it gets set earlier.
self.stop_event.wait(timeout=self.error_backoff_sec)
def _restart_stream(self):
self._watcher = watch.Watch()
for msg in self._watcher.stream(
self._read_pod_log_fn,
name=self.pod_name,
namespace=self.namespace_name,
timestamps=self.log_timestamps,
follow=True,
):
self._write(msg)
# After every message, check whether a stop was requested.
if self.stop_event.is_set():
self._stop()
return
def _write(self, msg: str, force_flush: bool = False):
self._out_stream.write(msg)
self._out_stream.write("\n")
if force_flush:
self.flush()
if self.log_to_stdout:
logger.info(msg)
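# Illustrative usage sketch (the kubernetes client object and names are
# hypothetical):
#
#     stop = threading.Event()
#     collector = PodLogCollector(
#         pod_name="psm-grpc-server-0",
#         namespace_name="test-namespace",
#         read_pod_log_fn=core_v1_api.read_namespaced_pod_log,
#         stop_event=stop,
#         log_path=pathlib.Path("logs/server.log"),
#     )
#     collector.start()
#     ...
#     stop.set()
#     collector.drain_event.wait(timeout=30)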

@@ -1,133 +0,0 @@
# Copyright 2022 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import subprocess
import time
from typing import Optional
logger = logging.getLogger(__name__)
class PortForwardingError(Exception):
"""Error forwarding port"""
class PortForwarder:
PORT_FORWARD_LOCAL_ADDRESS: str = "127.0.0.1"
def __init__(
self,
context: str,
namespace: str,
destination: str,
remote_port: int,
local_port: Optional[int] = None,
local_address: Optional[str] = None,
):
self.context = context
self.namespace = namespace
self.destination = destination
self.remote_port = remote_port
self.local_address = local_address or self.PORT_FORWARD_LOCAL_ADDRESS
self.local_port: Optional[int] = local_port
self.subprocess: Optional[subprocess.Popen] = None
def connect(self) -> None:
if self.local_port:
port_mapping = f"{self.local_port}:{self.remote_port}"
else:
port_mapping = f":{self.remote_port}"
cmd = [
"kubectl",
"--context",
self.context,
"--namespace",
self.namespace,
"port-forward",
"--address",
self.local_address,
self.destination,
port_mapping,
]
logger.debug(
"Executing port forwarding subprocess cmd: %s", " ".join(cmd)
)
self.subprocess = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
)
# Wait for stdout line indicating successful start.
if self.local_port:
local_port_expected = (
f"Forwarding from {self.local_address}:{self.local_port}"
f" -> {self.remote_port}"
)
else:
local_port_re = re.compile(
f"Forwarding from {self.local_address}:([0-9]+) ->"
f" {self.remote_port}"
)
try:
while True:
time.sleep(0.05)
output = self.subprocess.stdout.readline().strip()
if not output:
return_code = self.subprocess.poll()
if return_code is not None:
errors = self.subprocess.stdout.readlines()
raise PortForwardingError(
"Error forwarding port, kubectl return "
f"code {return_code}, output {errors}"
)
# If there is no output, and the subprocess is not exiting,
# continue waiting for the log line.
continue
# Validate output log
if self.local_port:
if output != local_port_expected:
raise PortForwardingError(
f"Error forwarding port, unexpected output {output}"
)
else:
groups = local_port_re.search(output)
if groups is None:
raise PortForwardingError(
f"Error forwarding port, unexpected output {output}"
)
# Update local port to the randomly picked one
self.local_port = int(groups[1])
logger.info(output)
break
except Exception:
self.close()
raise
def close(self) -> None:
if self.subprocess is not None:
logger.info(
"Shutting down port forwarding, pid %s", self.subprocess.pid
)
self.subprocess.kill()
stdout, _ = self.subprocess.communicate(timeout=5)
logger.info("Port forwarding stopped")
logger.debug("Port forwarding remaining stdout: %s", stdout)
self.subprocess = None
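# Illustrative usage sketch (context/namespace/pod names are hypothetical):
#
#     pf = PortForwarder(
#         context="gke_my-project_us-central1-a_my-cluster",
#         namespace="test-namespace",
#         destination="pod/psm-grpc-client-0",
#         remote_port=8079,
#     )
#     pf.connect()          # blocks until kubectl reports the forwarded port
#     print(pf.local_port)  # populated with the randomly chosen local port
#     pf.close()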

@@ -1,22 +0,0 @@
# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import framework.infrastructure.traffic_director as td_base
# TODO(sergiitk): [GAMMA] make a TD-manager-less base test case.
class TrafficDirectorGammaManager(td_base.TrafficDirectorManager):
"""Gamma."""
def cleanup(self, *, force=False): # pylint: disable=unused-argument
return True

@@ -1,14 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from framework.rpc import grpc

@@ -1,117 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Any, Dict, Optional
from google.protobuf import json_format
import google.protobuf.message
import grpc
import framework.errors
logger = logging.getLogger(__name__)
# Type aliases
Message = google.protobuf.message.Message
RpcError = grpc.RpcError
class GrpcClientHelper:
DEFAULT_RPC_DEADLINE_SEC = 90
channel: grpc.Channel
# This is purely cosmetic to make RPC logs look like method calls.
log_service_name: str
# This is purely cosmetic to output the RPC target. Normally set to the
# hostname:port of the remote service, but it doesn't have to be the
# real target. This way, when RPCs are routed through a proxy or port
# forwarding, the logs still show a useful name.
log_target: str
def __init__(
self,
channel: grpc.Channel,
stub_class: Any,
*,
log_target: Optional[str] = "",
):
self.channel = channel
self.stub = stub_class(channel)
self.log_service_name = re.sub(
"Stub$", "", self.stub.__class__.__name__
)
self.log_target = log_target or ""
def call_unary_with_deadline(
self,
*,
rpc: str,
req: Message,
deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC,
log_level: Optional[int] = logging.DEBUG,
) -> Message:
if deadline_sec is None:
deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC
call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec)
self._log_rpc_request(rpc, req, call_kwargs, log_level)
# Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options)
rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
return rpc_callable(req, **call_kwargs)
def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
logger.log(
logging.DEBUG if log_level is None else log_level,
"[%s] >> RPC %s.%s(request=%s(%r), %s)",
self.log_target,
self.log_service_name,
rpc,
req.__class__.__name__,
json_format.MessageToDict(req),
", ".join({f"{k}={v}" for k, v in call_kwargs.items()}),
)
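# Illustrative usage sketch (the channelz stub is one concrete example; the
# channelz_pb2 modules are assumed to be imported by the caller):
#
#     helper = GrpcClientHelper(
#         grpc.insecure_channel("localhost:8081"),
#         channelz_pb2_grpc.ChannelzStub,
#         log_target="localhost:8081",
#     )
#     response = helper.call_unary_with_deadline(
#         rpc="GetTopChannels",
#         req=channelz_pb2.GetTopChannelsRequest(),
#     )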
class GrpcApp:
channels: Dict[int, grpc.Channel]
class NotFound(framework.errors.FrameworkError):
"""Requested resource not found"""
def __init__(self, rpc_host):
self.rpc_host = rpc_host
# Cache gRPC channels per port
self.channels = dict()
def _make_channel(self, port) -> grpc.Channel:
if port not in self.channels:
target = f"{self.rpc_host}:{port}"
self.channels[port] = grpc.insecure_channel(target)
return self.channels[port]
def close(self):
# Close all channels
for channel in self.channels.values():
channel.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
def __del__(self):
self.close()

@@ -1,273 +0,0 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This contains helpers for gRPC services defined in
https://github.com/grpc/grpc-proto/blob/master/grpc/channelz/v1/channelz.proto
"""
import ipaddress
import logging
from typing import Iterator, Optional
import grpc
from grpc_channelz.v1 import channelz_pb2
from grpc_channelz.v1 import channelz_pb2_grpc
import framework.rpc
logger = logging.getLogger(__name__)
# Type aliases
# Channel
Channel = channelz_pb2.Channel
ChannelData = channelz_pb2.ChannelData
ChannelConnectivityState = channelz_pb2.ChannelConnectivityState
ChannelState = ChannelConnectivityState.State # pylint: disable=no-member
_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
_GetTopChannelsResponse = channelz_pb2.GetTopChannelsResponse
# Subchannel
Subchannel = channelz_pb2.Subchannel
_GetSubchannelRequest = channelz_pb2.GetSubchannelRequest
_GetSubchannelResponse = channelz_pb2.GetSubchannelResponse
# Server
Server = channelz_pb2.Server
_GetServersRequest = channelz_pb2.GetServersRequest
_GetServersResponse = channelz_pb2.GetServersResponse
# Sockets
Socket = channelz_pb2.Socket
SocketRef = channelz_pb2.SocketRef
_GetSocketRequest = channelz_pb2.GetSocketRequest
_GetSocketResponse = channelz_pb2.GetSocketResponse
Address = channelz_pb2.Address
Security = channelz_pb2.Security
# Server Sockets
_GetServerSocketsRequest = channelz_pb2.GetServerSocketsRequest
_GetServerSocketsResponse = channelz_pb2.GetServerSocketsResponse
class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
stub: channelz_pb2_grpc.ChannelzStub
def __init__(
self, channel: grpc.Channel, *, log_target: Optional[str] = ""
):
super().__init__(
channel, channelz_pb2_grpc.ChannelzStub, log_target=log_target
)
@staticmethod
def is_sock_tcpip_address(address: Address):
return address.WhichOneof("address") == "tcpip_address"
@staticmethod
def is_ipv4(tcpip_address: Address.TcpIpAddress):
# According to proto, tcpip_address.ip_address is either IPv4 or IPv6.
# Correspondingly, it's either 4 bytes or 16 bytes in length.
return len(tcpip_address.ip_address) == 4
@classmethod
def sock_address_to_str(cls, address: Address):
if cls.is_sock_tcpip_address(address):
tcpip_address: Address.TcpIpAddress = address.tcpip_address
if cls.is_ipv4(tcpip_address):
ip = ipaddress.IPv4Address(tcpip_address.ip_address)
else:
ip = ipaddress.IPv6Address(tcpip_address.ip_address)
return f"{ip}:{tcpip_address.port}"
else:
raise NotImplementedError("Only tcpip_address implemented")
@classmethod
def sock_addresses_pretty(cls, socket: Socket):
return (
f"local={cls.sock_address_to_str(socket.local)}, "
f"remote={cls.sock_address_to_str(socket.remote)}"
)
@staticmethod
def find_server_socket_matching_client(
server_sockets: Iterator[Socket], client_socket: Socket
) -> Socket:
for server_socket in server_sockets:
if server_socket.remote == client_socket.local:
return server_socket
return None
@staticmethod
def channel_repr(channel: Channel) -> str:
result = f"<Channel channel_id={channel.ref.channel_id}"
if channel.data.target:
result += f" target={channel.data.target}"
result += (
f" call_started={channel.data.calls_started}"
+ f" calls_succeeded={channel.data.calls_succeeded}"
+ f" calls_failed={channel.data.calls_failed}"
)
result += f" state={ChannelState.Name(channel.data.state.state)}>"
return result
@staticmethod
def subchannel_repr(subchannel: Subchannel) -> str:
result = f"<Subchannel subchannel_id={subchannel.ref.subchannel_id}"
if subchannel.data.target:
result += f" target={subchannel.data.target}"
result += f" state={ChannelState.Name(subchannel.data.state.state)}>"
return result
def find_channels_for_target(
self, target: str, **kwargs
) -> Iterator[Channel]:
return (
channel
for channel in self.list_channels(**kwargs)
if channel.data.target == target
)
def find_server_listening_on_port(
self, port: int, **kwargs
) -> Optional[Server]:
for server in self.list_servers(**kwargs):
listen_socket_ref: SocketRef
for listen_socket_ref in server.listen_socket:
listen_socket = self.get_socket(
listen_socket_ref.socket_id, **kwargs
)
listen_address: Address = listen_socket.local
if (
self.is_sock_tcpip_address(listen_address)
and listen_address.tcpip_address.port == port
):
return server
return None
def list_channels(self, **kwargs) -> Iterator[Channel]:
"""
Iterate over all pages of all root channels.
Root channels are those which the application has directly created.
This does not include subchannels or non-top-level channels.
"""
start: int = -1
response: Optional[_GetTopChannelsResponse] = None
while start < 0 or not response.end:
# From proto: To request subsequent pages, the client generates this
# value by adding 1 to the highest seen result ID.
start += 1
response = self.call_unary_with_deadline(
rpc="GetTopChannels",
req=_GetTopChannelsRequest(start_channel_id=start),
**kwargs,
)
for channel in response.channel:
start = max(start, channel.ref.channel_id)
yield channel
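# Pagination sketch: channelz pages are keyed by result id, not by a page
# token. For example, if the first GetTopChannels(start_channel_id=0) page
# returns channels with ids up to 37 and end=False, the next request uses
# start_channel_id=38, and so on until end=True.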
def get_channel(self, channel_id, **kwargs) -> Channel:
"""Return a single Channel, otherwise raises RpcError."""
response: channelz_pb2.GetChannelResponse
try:
response = self.call_unary_with_deadline(
rpc="GetChannel",
req=channelz_pb2.GetChannelRequest(channel_id=channel_id),
**kwargs,
)
return response.channel
except grpc.RpcError as err:
if isinstance(err, grpc.Call):
# Translate NOT_FOUND into GrpcApp.NotFound.
if err.code() is grpc.StatusCode.NOT_FOUND:
raise framework.rpc.grpc.GrpcApp.NotFound(
f"Channel with channel_id {channel_id} not found",
)
raise
def list_servers(self, **kwargs) -> Iterator[Server]:
"""Iterate over all pages of all servers that exist in the process."""
start: int = -1
response: Optional[_GetServersResponse] = None
while start < 0 or not response.end:
# From proto: To request subsequent pages, the client generates this
# value by adding 1 to the highest seen result ID.
start += 1
response = self.call_unary_with_deadline(
rpc="GetServers",
req=_GetServersRequest(start_server_id=start),
**kwargs,
)
for server in response.server:
start = max(start, server.ref.server_id)
yield server
def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]:
"""List all server sockets that exist in server process.
Iterating over the results will resolve additional pages automatically.
"""
start: int = -1
response: Optional[_GetServerSocketsResponse] = None
while start < 0 or not response.end:
# From proto: To request subsequent pages, the client generates this
# value by adding 1 to the highest seen result ID.
start += 1
response = self.call_unary_with_deadline(
rpc="GetServerSockets",
req=_GetServerSocketsRequest(
server_id=server.ref.server_id, start_socket_id=start
),
**kwargs,
)
socket_ref: SocketRef
for socket_ref in response.socket_ref:
start = max(start, socket_ref.socket_id)
# Yield actual socket
yield self.get_socket(socket_ref.socket_id, **kwargs)
def list_channel_sockets(
self, channel: Channel, **kwargs
) -> Iterator[Socket]:
"""List all sockets of all subchannels of a given channel."""
for subchannel in self.list_channel_subchannels(channel, **kwargs):
yield from self.list_subchannels_sockets(subchannel, **kwargs)
def list_channel_subchannels(
self, channel: Channel, **kwargs
) -> Iterator[Subchannel]:
"""List all subchannels of a given channel."""
for subchannel_ref in channel.subchannel_ref:
yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs)
def list_subchannels_sockets(
self, subchannel: Subchannel, **kwargs
) -> Iterator[Socket]:
"""List all sockets of a given subchannel."""
for socket_ref in subchannel.socket_ref:
yield self.get_socket(socket_ref.socket_id, **kwargs)
def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel:
"""Return a single Subchannel, otherwise raises RpcError."""
response: _GetSubchannelResponse = self.call_unary_with_deadline(
rpc="GetSubchannel",
req=_GetSubchannelRequest(subchannel_id=subchannel_id),
**kwargs,
)
return response.subchannel
def get_socket(self, socket_id, **kwargs) -> Socket:
"""Return a single Socket, otherwise raises RpcError."""
response: _GetSocketResponse = self.call_unary_with_deadline(
rpc="GetSocket",
req=_GetSocketRequest(socket_id=socket_id),
**kwargs,
)
return response.socket
