From 82e506c7b2f632eb8fcf9c5e6a9f4235ef19414d Mon Sep 17 00:00:00 2001 From: Xuan Wang Date: Fri, 28 Jul 2023 13:26:13 -0700 Subject: [PATCH] [python O11Y] Refactor census propagation flow. (#33561) Refactored OpenCensus context propagation flow, now propagation happens for each call and context will be automatically propagated from gRPC server to gRPC client. We're using `execution_context` in OpenCensus since the context is related to OpenCensus and it helps wrap `contextVar` for us. ### Testing * Added a new Bazel test case for context propagation. --- src/python/grpcio/grpc/_server.py | 1 - .../grpc_observability/_gcp_observability.py | 20 ++- .../observability/_observability_test.py | 149 +++++++++++++----- 3 files changed, 120 insertions(+), 50 deletions(-) diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 84aae73b246..425aff8c955 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -1258,7 +1258,6 @@ def _start(state: _ServerState) -> None: state.server.start() state.stage = _ServerStage.STARTED _request_call(state) - thread = threading.Thread(target=_serve, args=(state,)) thread.daemon = True thread.start() diff --git a/src/python/grpcio_observability/grpc_observability/_gcp_observability.py b/src/python/grpcio_observability/grpc_observability/_gcp_observability.py index b56f8b9b257..d62653bcd58 100644 --- a/src/python/grpcio_observability/grpc_observability/_gcp_observability.py +++ b/src/python/grpcio_observability/grpc_observability/_gcp_observability.py @@ -56,6 +56,8 @@ GRPC_STATUS_CODE_TO_STRING = { grpc.StatusCode.DATA_LOSS: "DATA_LOSS", } +GRPC_SPAN_CONTEXT = "grpc_span_context" + @dataclass class GcpObservabilityPythonConfig: @@ -165,11 +167,12 @@ class GCPOpenCensusObservability(grpc._observability.ObservabilityPlugin): def create_client_call_tracer( self, method_name: bytes ) -> ClientCallTracerCapsule: - current_span = execution_context.get_current_span() - if current_span: - # Propagate existing OC context - trace_id = current_span.context_tracer.trace_id.encode("utf8") - parent_span_id = current_span.span_id.encode("utf8") + grpc_span_context = execution_context.get_opencensus_attr( + GRPC_SPAN_CONTEXT + ) + if grpc_span_context: + trace_id = grpc_span_context.trace_id.encode("utf8") + parent_span_id = grpc_span_context.span_id.encode("utf8") capsule = _cyobservability.create_client_call_tracer( method_name, trace_id, parent_span_id ) @@ -197,10 +200,11 @@ class GCPOpenCensusObservability(grpc._observability.ObservabilityPlugin): trace_options = trace_options_module.TraceOptions(0) trace_options.set_enabled(is_sampled) span_context = span_context_module.SpanContext( - trace_id=trace_id, span_id=span_id, trace_options=trace_options + trace_id=trace_id, + span_id=span_id, + trace_options=trace_options, ) - current_tracer = execution_context.get_opencensus_tracer() - current_tracer.span_context = span_context + execution_context.set_opencensus_attr(GRPC_SPAN_CONTEXT, span_context) def record_rpc_latency( self, method: str, rpc_latency: float, status_code: grpc.StatusCode diff --git a/src/python/grpcio_tests/tests/observability/_observability_test.py b/src/python/grpcio_tests/tests/observability/_observability_test.py index df657934d45..f7b4fe21a52 100644 --- a/src/python/grpcio_tests/tests/observability/_observability_test.py +++ b/src/python/grpcio_tests/tests/observability/_observability_test.py @@ -35,6 +35,8 @@ _UNARY_STREAM = "/test/UnaryStream" _STREAM_UNARY = "/test/StreamUnary" _STREAM_STREAM = "/test/StreamStream" STREAM_LENGTH = 5 +TRIGGER_RPC_METADATA = ("control", "trigger_rpc") +TRIGGER_RPC_TO_NEW_SERVER_METADATA = ("to_new_server", "") CONFIG_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG" CONFIG_FILE_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG_FILE" @@ -86,6 +88,19 @@ class TestExporter(_observability.Exporter): def handle_unary_unary(request, servicer_context): + if TRIGGER_RPC_METADATA in servicer_context.invocation_metadata(): + for k, v in servicer_context.invocation_metadata(): + if "port" in k: + unary_unary_call(port=int(v)) + if "to_new_server" in k: + second_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=10) + ) + second_server.add_generic_rpc_handlers((_GenericHandler(),)) + second_server_port = second_server.add_insecure_port("[::]:0") + second_server.start() + unary_unary_call(port=second_server_port) + second_server.stop(0) return _RESPONSE @@ -157,13 +172,57 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.unary_unary_call() + unary_unary_call(port=self._port) self.assertGreater(len(self.all_metric), 0) self.assertGreater(len(self.all_span), 0) self._validate_metrics(self.all_metric) self._validate_spans(self.all_span) + def testContextPropagationToSameServer(self): + # Sends two RPCs, one from gRPC client and the other from gRPC server: + # gRPC Client -> gRPC Server 1 -> gRPC Server 1 + # Verify that the trace_id was propagated to the 2nd RPC. + self._set_config_file(_VALID_CONFIG_TRACING_ONLY) + with grpc_observability.GCPOpenCensusObservability( + exporter=self.test_exporter + ): + port = self._start_server() + metadata = ( + TRIGGER_RPC_METADATA, + ("port", str(port)), + ) + unary_unary_call(port=port, metadata=metadata) + + # 2 of each for ["Recv", "Sent", "Attempt"] + self.assertEqual(len(self.all_span), 6) + trace_id = self.all_span[0].trace_id + for span in self.all_span: + self.assertEqual(span.trace_id, trace_id) + + def testContextPropagationToNewServer(self): + # Sends two RPCs, one from gRPC client and the other from gRPC server: + # gRPC Client -> gRPC Server 1 -> gRPC Server 2 + # Verify that the trace_id was propagated to the 2nd RPC. + # This test case is to make sure that the context from one thread can + # be propagated to different thread. + self._set_config_file(_VALID_CONFIG_TRACING_ONLY) + with grpc_observability.GCPOpenCensusObservability( + exporter=self.test_exporter + ): + port = self._start_server() + metadata = ( + TRIGGER_RPC_METADATA, + TRIGGER_RPC_TO_NEW_SERVER_METADATA, + ) + unary_unary_call(port=port, metadata=metadata) + + # 2 of each for ["Recv", "Sent", "Attempt"] + self.assertEqual(len(self.all_span), 6) + trace_id = self.all_span[0].trace_id + for span in self.all_span: + self.assertEqual(span.trace_id, trace_id) + def testThrowErrorWithoutConfig(self): with self.assertRaises(ValueError): with grpc_observability.GCPOpenCensusObservability( @@ -189,7 +248,7 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.unary_unary_call() + unary_unary_call(port=self._port) self.assertEqual(len(self.all_metric), 0) self.assertEqual(len(self.all_span), 0) @@ -208,7 +267,7 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.unary_unary_call() + unary_unary_call(port=self._port) self.assertEqual(len(self.all_span), 0) self.assertGreater(len(self.all_metric), 0) @@ -220,7 +279,7 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.unary_unary_call() + unary_unary_call(port=self._port) self.assertEqual(len(self.all_metric), 0) self.assertGreater(len(self.all_span), 0) @@ -232,7 +291,7 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.unary_stream_call() + unary_stream_call(port=self._port) self.assertGreater(len(self.all_metric), 0) self.assertGreater(len(self.all_span), 0) @@ -245,7 +304,7 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.stream_unary_call() + stream_unary_call(port=self._port) self.assertTrue(len(self.all_metric) > 0) self.assertTrue(len(self.all_span) > 0) @@ -258,7 +317,7 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.stream_stream_call() + stream_stream_call(port=self._port) self.assertGreater(len(self.all_metric), 0) self.assertGreater(len(self.all_span), 0) @@ -268,7 +327,7 @@ class ObservabilityTest(unittest.TestCase): def testNoRecordBeforeInit(self): self._set_config_file(_VALID_CONFIG_TRACING_STATS) self._start_server() - self.unary_unary_call() + unary_unary_call(port=self._port) self.assertEqual(len(self.all_metric), 0) self.assertEqual(len(self.all_span), 0) self._server.stop(0) @@ -277,7 +336,7 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.unary_unary_call() + unary_unary_call(port=self._port) self.assertGreater(len(self.all_metric), 0) self.assertGreater(len(self.all_span), 0) @@ -290,7 +349,7 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.unary_unary_call() + unary_unary_call(port=self._port) self.assertGreater(len(self.all_metric), 0) self.assertGreater(len(self.all_span), 0) @@ -299,7 +358,7 @@ class ObservabilityTest(unittest.TestCase): self._validate_metrics(self.all_metric) self._validate_spans(self.all_span) - self.unary_unary_call() + unary_unary_call(port=self._port) self.assertEqual(len(self.all_metric), current_metric_len) self.assertEqual(len(self.all_span), current_spans_len) @@ -320,8 +379,7 @@ class ObservabilityTest(unittest.TestCase): ): self._start_server() for _ in range(_CALLS): - self.unary_unary_call() - + unary_unary_call(port=self._port) self.assertEqual(len(self.all_metric), 0) self.assertGreaterEqual(len(self.all_span), _LOWER_BOUND) self.assertLessEqual(len(self.all_span), _HIGHER_BOUND) @@ -337,7 +395,7 @@ class ObservabilityTest(unittest.TestCase): exporter=self.test_exporter ): self._start_server() - self.unary_unary_call() + unary_unary_call(port=self._port) self.assertEqual(len(self.all_metric), 0) self.assertGreater(len(self.all_span), 0) @@ -350,38 +408,12 @@ class ObservabilityTest(unittest.TestCase): f.write(json.dumps(config)) os.environ[CONFIG_FILE_ENV_VAR_NAME] = config_file_path - def unary_unary_call(self): - with grpc.insecure_channel(f"localhost:{self._port}") as channel: - multi_callable = channel.unary_unary(_UNARY_UNARY) - unused_response, call = multi_callable.with_call(_REQUEST) - - def unary_stream_call(self): - with grpc.insecure_channel(f"localhost:{self._port}") as channel: - multi_callable = channel.unary_stream(_UNARY_STREAM) - call = multi_callable(_REQUEST) - for _ in call: - pass - - def stream_unary_call(self): - with grpc.insecure_channel(f"localhost:{self._port}") as channel: - multi_callable = channel.stream_unary(_STREAM_UNARY) - unused_response, call = multi_callable.with_call( - iter([_REQUEST] * STREAM_LENGTH) - ) - - def stream_stream_call(self): - with grpc.insecure_channel(f"localhost:{self._port}") as channel: - multi_callable = channel.stream_stream(_STREAM_STREAM) - call = multi_callable(iter([_REQUEST] * STREAM_LENGTH)) - for _ in call: - pass - def _start_server(self) -> None: self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) self._server.add_generic_rpc_handlers((_GenericHandler(),)) self._port = self._server.add_insecure_port("[::]:0") - self._server.start() + return self._port def _validate_metrics( self, metrics: List[_observability.StatsData] @@ -413,6 +445,41 @@ class ObservabilityTest(unittest.TestCase): self.assertTrue(prefix_exist) +def unary_unary_call(port, metadata=None): + with grpc.insecure_channel(f"localhost:{port}") as channel: + multi_callable = channel.unary_unary(_UNARY_UNARY) + if metadata: + unused_response, call = multi_callable.with_call( + _REQUEST, metadata=metadata + ) + else: + unused_response, call = multi_callable.with_call(_REQUEST) + + +def unary_stream_call(port): + with grpc.insecure_channel(f"localhost:{port}") as channel: + multi_callable = channel.unary_stream(_UNARY_STREAM) + call = multi_callable(_REQUEST) + for _ in call: + pass + + +def stream_unary_call(port): + with grpc.insecure_channel(f"localhost:{port}") as channel: + multi_callable = channel.stream_unary(_STREAM_UNARY) + unused_response, call = multi_callable.with_call( + iter([_REQUEST] * STREAM_LENGTH) + ) + + +def stream_stream_call(port): + with grpc.insecure_channel(f"localhost:{port}") as channel: + multi_callable = channel.stream_stream(_STREAM_STREAM) + call = multi_callable(iter([_REQUEST] * STREAM_LENGTH)) + for _ in call: + pass + + if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2)