[python O11Y] Refactor census propagation flow. (#33561)

Refactored OpenCensus context propagation flow, now propagation happens
for each call and context will be automatically propagated from gRPC
server to gRPC client.

We're using `execution_context` in OpenCensus since the context is
related to OpenCensus and it helps wrap `contextVar` for us.

### Testing
* Added a new Bazel test case for context propagation. 

<!--

If you know who should review your pull request, please assign it to
that
person, otherwise the pull request would get assigned randomly.

If your pull request is for a specific language, please add the
appropriate
lang label.

-->
pull/33919/head
Xuan Wang 2 years ago committed by GitHub
parent 6b2de0fa4b
commit 82e506c7b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      src/python/grpcio/grpc/_server.py
  2. 20
      src/python/grpcio_observability/grpc_observability/_gcp_observability.py
  3. 149
      src/python/grpcio_tests/tests/observability/_observability_test.py

@ -1258,7 +1258,6 @@ def _start(state: _ServerState) -> None:
state.server.start()
state.stage = _ServerStage.STARTED
_request_call(state)
thread = threading.Thread(target=_serve, args=(state,))
thread.daemon = True
thread.start()

@ -56,6 +56,8 @@ GRPC_STATUS_CODE_TO_STRING = {
grpc.StatusCode.DATA_LOSS: "DATA_LOSS",
}
GRPC_SPAN_CONTEXT = "grpc_span_context"
@dataclass
class GcpObservabilityPythonConfig:
@ -165,11 +167,12 @@ class GCPOpenCensusObservability(grpc._observability.ObservabilityPlugin):
def create_client_call_tracer(
self, method_name: bytes
) -> ClientCallTracerCapsule:
current_span = execution_context.get_current_span()
if current_span:
# Propagate existing OC context
trace_id = current_span.context_tracer.trace_id.encode("utf8")
parent_span_id = current_span.span_id.encode("utf8")
grpc_span_context = execution_context.get_opencensus_attr(
GRPC_SPAN_CONTEXT
)
if grpc_span_context:
trace_id = grpc_span_context.trace_id.encode("utf8")
parent_span_id = grpc_span_context.span_id.encode("utf8")
capsule = _cyobservability.create_client_call_tracer(
method_name, trace_id, parent_span_id
)
@ -197,10 +200,11 @@ class GCPOpenCensusObservability(grpc._observability.ObservabilityPlugin):
trace_options = trace_options_module.TraceOptions(0)
trace_options.set_enabled(is_sampled)
span_context = span_context_module.SpanContext(
trace_id=trace_id, span_id=span_id, trace_options=trace_options
trace_id=trace_id,
span_id=span_id,
trace_options=trace_options,
)
current_tracer = execution_context.get_opencensus_tracer()
current_tracer.span_context = span_context
execution_context.set_opencensus_attr(GRPC_SPAN_CONTEXT, span_context)
def record_rpc_latency(
self, method: str, rpc_latency: float, status_code: grpc.StatusCode

@ -35,6 +35,8 @@ _UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"
STREAM_LENGTH = 5
TRIGGER_RPC_METADATA = ("control", "trigger_rpc")
TRIGGER_RPC_TO_NEW_SERVER_METADATA = ("to_new_server", "")
CONFIG_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG"
CONFIG_FILE_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG_FILE"
@ -86,6 +88,19 @@ class TestExporter(_observability.Exporter):
def handle_unary_unary(request, servicer_context):
if TRIGGER_RPC_METADATA in servicer_context.invocation_metadata():
for k, v in servicer_context.invocation_metadata():
if "port" in k:
unary_unary_call(port=int(v))
if "to_new_server" in k:
second_server = grpc.server(
futures.ThreadPoolExecutor(max_workers=10)
)
second_server.add_generic_rpc_handlers((_GenericHandler(),))
second_server_port = second_server.add_insecure_port("[::]:0")
second_server.start()
unary_unary_call(port=second_server_port)
second_server.stop(0)
return _RESPONSE
@ -157,13 +172,57 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertGreater(len(self.all_metric), 0)
self.assertGreater(len(self.all_span), 0)
self._validate_metrics(self.all_metric)
self._validate_spans(self.all_span)
def testContextPropagationToSameServer(self):
# Sends two RPCs, one from gRPC client and the other from gRPC server:
# gRPC Client -> gRPC Server 1 -> gRPC Server 1
# Verify that the trace_id was propagated to the 2nd RPC.
self._set_config_file(_VALID_CONFIG_TRACING_ONLY)
with grpc_observability.GCPOpenCensusObservability(
exporter=self.test_exporter
):
port = self._start_server()
metadata = (
TRIGGER_RPC_METADATA,
("port", str(port)),
)
unary_unary_call(port=port, metadata=metadata)
# 2 of each for ["Recv", "Sent", "Attempt"]
self.assertEqual(len(self.all_span), 6)
trace_id = self.all_span[0].trace_id
for span in self.all_span:
self.assertEqual(span.trace_id, trace_id)
def testContextPropagationToNewServer(self):
# Sends two RPCs, one from gRPC client and the other from gRPC server:
# gRPC Client -> gRPC Server 1 -> gRPC Server 2
# Verify that the trace_id was propagated to the 2nd RPC.
# This test case is to make sure that the context from one thread can
# be propagated to different thread.
self._set_config_file(_VALID_CONFIG_TRACING_ONLY)
with grpc_observability.GCPOpenCensusObservability(
exporter=self.test_exporter
):
port = self._start_server()
metadata = (
TRIGGER_RPC_METADATA,
TRIGGER_RPC_TO_NEW_SERVER_METADATA,
)
unary_unary_call(port=port, metadata=metadata)
# 2 of each for ["Recv", "Sent", "Attempt"]
self.assertEqual(len(self.all_span), 6)
trace_id = self.all_span[0].trace_id
for span in self.all_span:
self.assertEqual(span.trace_id, trace_id)
def testThrowErrorWithoutConfig(self):
with self.assertRaises(ValueError):
with grpc_observability.GCPOpenCensusObservability(
@ -189,7 +248,7 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertEqual(len(self.all_metric), 0)
self.assertEqual(len(self.all_span), 0)
@ -208,7 +267,7 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertEqual(len(self.all_span), 0)
self.assertGreater(len(self.all_metric), 0)
@ -220,7 +279,7 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertEqual(len(self.all_metric), 0)
self.assertGreater(len(self.all_span), 0)
@ -232,7 +291,7 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.unary_stream_call()
unary_stream_call(port=self._port)
self.assertGreater(len(self.all_metric), 0)
self.assertGreater(len(self.all_span), 0)
@ -245,7 +304,7 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.stream_unary_call()
stream_unary_call(port=self._port)
self.assertTrue(len(self.all_metric) > 0)
self.assertTrue(len(self.all_span) > 0)
@ -258,7 +317,7 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.stream_stream_call()
stream_stream_call(port=self._port)
self.assertGreater(len(self.all_metric), 0)
self.assertGreater(len(self.all_span), 0)
@ -268,7 +327,7 @@ class ObservabilityTest(unittest.TestCase):
def testNoRecordBeforeInit(self):
self._set_config_file(_VALID_CONFIG_TRACING_STATS)
self._start_server()
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertEqual(len(self.all_metric), 0)
self.assertEqual(len(self.all_span), 0)
self._server.stop(0)
@ -277,7 +336,7 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertGreater(len(self.all_metric), 0)
self.assertGreater(len(self.all_span), 0)
@ -290,7 +349,7 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertGreater(len(self.all_metric), 0)
self.assertGreater(len(self.all_span), 0)
@ -299,7 +358,7 @@ class ObservabilityTest(unittest.TestCase):
self._validate_metrics(self.all_metric)
self._validate_spans(self.all_span)
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertEqual(len(self.all_metric), current_metric_len)
self.assertEqual(len(self.all_span), current_spans_len)
@ -320,8 +379,7 @@ class ObservabilityTest(unittest.TestCase):
):
self._start_server()
for _ in range(_CALLS):
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertEqual(len(self.all_metric), 0)
self.assertGreaterEqual(len(self.all_span), _LOWER_BOUND)
self.assertLessEqual(len(self.all_span), _HIGHER_BOUND)
@ -337,7 +395,7 @@ class ObservabilityTest(unittest.TestCase):
exporter=self.test_exporter
):
self._start_server()
self.unary_unary_call()
unary_unary_call(port=self._port)
self.assertEqual(len(self.all_metric), 0)
self.assertGreater(len(self.all_span), 0)
@ -350,38 +408,12 @@ class ObservabilityTest(unittest.TestCase):
f.write(json.dumps(config))
os.environ[CONFIG_FILE_ENV_VAR_NAME] = config_file_path
def unary_unary_call(self):
with grpc.insecure_channel(f"localhost:{self._port}") as channel:
multi_callable = channel.unary_unary(_UNARY_UNARY)
unused_response, call = multi_callable.with_call(_REQUEST)
def unary_stream_call(self):
with grpc.insecure_channel(f"localhost:{self._port}") as channel:
multi_callable = channel.unary_stream(_UNARY_STREAM)
call = multi_callable(_REQUEST)
for _ in call:
pass
def stream_unary_call(self):
with grpc.insecure_channel(f"localhost:{self._port}") as channel:
multi_callable = channel.stream_unary(_STREAM_UNARY)
unused_response, call = multi_callable.with_call(
iter([_REQUEST] * STREAM_LENGTH)
)
def stream_stream_call(self):
with grpc.insecure_channel(f"localhost:{self._port}") as channel:
multi_callable = channel.stream_stream(_STREAM_STREAM)
call = multi_callable(iter([_REQUEST] * STREAM_LENGTH))
for _ in call:
pass
def _start_server(self) -> None:
self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
self._server.add_generic_rpc_handlers((_GenericHandler(),))
self._port = self._server.add_insecure_port("[::]:0")
self._server.start()
return self._port
def _validate_metrics(
self, metrics: List[_observability.StatsData]
@ -413,6 +445,41 @@ class ObservabilityTest(unittest.TestCase):
self.assertTrue(prefix_exist)
def unary_unary_call(port, metadata=None):
with grpc.insecure_channel(f"localhost:{port}") as channel:
multi_callable = channel.unary_unary(_UNARY_UNARY)
if metadata:
unused_response, call = multi_callable.with_call(
_REQUEST, metadata=metadata
)
else:
unused_response, call = multi_callable.with_call(_REQUEST)
def unary_stream_call(port):
with grpc.insecure_channel(f"localhost:{port}") as channel:
multi_callable = channel.unary_stream(_UNARY_STREAM)
call = multi_callable(_REQUEST)
for _ in call:
pass
def stream_unary_call(port):
with grpc.insecure_channel(f"localhost:{port}") as channel:
multi_callable = channel.stream_unary(_STREAM_UNARY)
unused_response, call = multi_callable.with_call(
iter([_REQUEST] * STREAM_LENGTH)
)
def stream_stream_call(port):
with grpc.insecure_channel(f"localhost:{port}") as channel:
multi_callable = channel.stream_stream(_STREAM_STREAM)
call = multi_callable(iter([_REQUEST] * STREAM_LENGTH))
for _ in call:
pass
if __name__ == "__main__":
logging.basicConfig()
unittest.main(verbosity=2)

Loading…
Cancel
Save