diff --git a/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py b/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py index 8504f05c3d0..8cc0d4d2fcc 100644 --- a/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py +++ b/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py @@ -69,40 +69,54 @@ class _GenericHandler(grpc.GenericRpcHandler): return None -class _ChannelServerPair(object): +class _ChannelServerPair: async def start(self): # Server will enable channelz service self.server = aio.server(options=_DISABLE_REUSE_PORT + _ENABLE_CHANNELZ) port = self.server.add_insecure_port('[::]:0') + self.address = 'localhost:%d' % port self.server.add_generic_rpc_handlers((_GenericHandler(),)) await self.server.start() # Channel will enable channelz service... - self.channel = aio.insecure_channel('localhost:%d' % port, + self.channel = aio.insecure_channel(self.address, options=_ENABLE_CHANNELZ) + async def bind_channelz(self, channelz_stub): + resp = await channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + for channel in resp.channel: + if channel.data.target == self.address: + self.channel_ref_id = channel.ref.channel_id + + resp = await channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.server_ref_id = resp.server[-1].ref.server_id -# Stores channel-server pairs globally, since the memory deallocation is -# non-deterministic in both Core and Python with multiple threads. The -# destroyed Channelz node might still present. So, as a work around, this -# test doesn't close channel-server-pairs between cases. -_pairs = [] + async def stop(self): + await self.channel.close() + await self.server.stop(None) -async def _generate_channel_server_pairs(n): - """Creates channel-server pairs globally, returns their indexes.""" - new_pairs = [_ChannelServerPair() for i in range(n)] - for pair in new_pairs: +async def _create_channel_server_pairs(n, channelz_stub=None): + """Create channel-server pairs.""" + pairs = [_ChannelServerPair() for i in range(n)] + for pair in pairs: await pair.start() - _pairs.extend(new_pairs) - return list(range(len(_pairs) - n, len(_pairs))) + if channelz_stub: + await pair.bind_channelz(channelz_stub) + return pairs + + +async def _destroy_channel_server_pairs(pairs): + for pair in pairs: + await pair.stop() class ChannelzServicerTest(AioTestBase): async def setUp(self): - self._pairs = [] # This server is for Channelz info fetching only # It self should not enable Channelz self._server = aio.server(options=_DISABLE_REUSE_PORT + @@ -118,155 +132,149 @@ class ChannelzServicerTest(AioTestBase): self._channelz_stub = channelz_pb2_grpc.ChannelzStub(self._channel) async def tearDown(self): - await self._server.stop(None) await self._channel.close() + await self._server.stop(None) + + async def _get_server_by_ref_id(self, ref_id): + """Server id may not be consecutive""" + resp = await self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=ref_id)) + self.assertEqual(ref_id, resp.server[0].ref.server_id) + return resp.server[0] - async def _send_successful_unary_unary(self, idx): - call = _pairs[idx].channel.unary_unary(_SUCCESSFUL_UNARY_UNARY)( - _REQUEST) + async def _send_successful_unary_unary(self, pair): + call = pair.channel.unary_unary(_SUCCESSFUL_UNARY_UNARY)(_REQUEST) self.assertEqual(grpc.StatusCode.OK, await call.code()) - async def _send_failed_unary_unary(self, idx): + async def _send_failed_unary_unary(self, pair): try: - await _pairs[idx].channel.unary_unary(_FAILED_UNARY_UNARY)(_REQUEST) + await pair.channel.unary_unary(_FAILED_UNARY_UNARY)(_REQUEST) except grpc.RpcError: return else: self.fail("This call supposed to fail") - async def _send_successful_stream_stream(self, idx): - call = _pairs[idx].channel.stream_stream(_SUCCESSFUL_STREAM_STREAM)( - iter([_REQUEST] * test_constants.STREAM_LENGTH)) + async def _send_successful_stream_stream(self, pair): + call = pair.channel.stream_stream(_SUCCESSFUL_STREAM_STREAM)(iter( + [_REQUEST] * test_constants.STREAM_LENGTH)) cnt = 0 async for _ in call: cnt += 1 self.assertEqual(cnt, test_constants.STREAM_LENGTH) - async def _get_channel_id(self, idx): - """Channel id may not be consecutive""" - resp = await self._channelz_stub.GetTopChannels( - channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) - self.assertGreater(len(resp.channel), idx) - return resp.channel[idx].ref.channel_id - - async def _get_server_by_id(self, idx): - """Server id may not be consecutive""" - resp = await self._channelz_stub.GetServers( - channelz_pb2.GetServersRequest(start_server_id=0)) - return resp.server[idx] - - async def test_get_top_channels_basic(self): - before = await self._channelz_stub.GetTopChannels( - channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) - await _generate_channel_server_pairs(1) - after = await self._channelz_stub.GetTopChannels( - channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) - self.assertEqual(len(after.channel) - len(before.channel), 1) - self.assertEqual(after.end, True) - async def test_get_top_channels_high_start_id(self): - await _generate_channel_server_pairs(1) + pairs = await _create_channel_server_pairs(1) + resp = await self._channelz_stub.GetTopChannels( channelz_pb2.GetTopChannelsRequest( start_channel_id=_LARGE_UNASSIGNED_ID)) self.assertEqual(len(resp.channel), 0) self.assertEqual(resp.end, True) + await _destroy_channel_server_pairs(pairs) + async def test_successful_request(self): - idx = await _generate_channel_server_pairs(1) - await self._send_successful_unary_unary(idx[0]) + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + + await self._send_successful_unary_unary(pairs[0]) resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[0]))) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + self.assertEqual(resp.channel.data.calls_started, 1) self.assertEqual(resp.channel.data.calls_succeeded, 1) self.assertEqual(resp.channel.data.calls_failed, 0) + await _destroy_channel_server_pairs(pairs) + async def test_failed_request(self): - idx = await _generate_channel_server_pairs(1) - await self._send_failed_unary_unary(idx[0]) + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + + await self._send_failed_unary_unary(pairs[0]) resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[0]))) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) self.assertEqual(resp.channel.data.calls_started, 1) self.assertEqual(resp.channel.data.calls_succeeded, 0) self.assertEqual(resp.channel.data.calls_failed, 1) + await _destroy_channel_server_pairs(pairs) + async def test_many_requests(self): - idx = await _generate_channel_server_pairs(1) + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + k_success = 7 k_failed = 9 for i in range(k_success): - await self._send_successful_unary_unary(idx[0]) + await self._send_successful_unary_unary(pairs[0]) for i in range(k_failed): - await self._send_failed_unary_unary(idx[0]) + await self._send_failed_unary_unary(pairs[0]) resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[0]))) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) self.assertEqual(resp.channel.data.calls_succeeded, k_success) self.assertEqual(resp.channel.data.calls_failed, k_failed) + await _destroy_channel_server_pairs(pairs) + async def test_many_requests_many_channel(self): k_channels = 4 - idx = await _generate_channel_server_pairs(k_channels) + pairs = await _create_channel_server_pairs(k_channels, + self._channelz_stub) k_success = 11 k_failed = 13 for i in range(k_success): - await self._send_successful_unary_unary(idx[0]) - await self._send_successful_unary_unary(idx[2]) + await self._send_successful_unary_unary(pairs[0]) + await self._send_successful_unary_unary(pairs[2]) for i in range(k_failed): - await self._send_failed_unary_unary(idx[1]) - await self._send_failed_unary_unary(idx[2]) + await self._send_failed_unary_unary(pairs[1]) + await self._send_failed_unary_unary(pairs[2]) # The first channel saw only successes resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[0]))) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) self.assertEqual(resp.channel.data.calls_started, k_success) self.assertEqual(resp.channel.data.calls_succeeded, k_success) self.assertEqual(resp.channel.data.calls_failed, 0) # The second channel saw only failures resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[1]))) + channelz_pb2.GetChannelRequest(channel_id=pairs[1].channel_ref_id)) self.assertEqual(resp.channel.data.calls_started, k_failed) self.assertEqual(resp.channel.data.calls_succeeded, 0) self.assertEqual(resp.channel.data.calls_failed, k_failed) # The third channel saw both successes and failures resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[2]))) + channelz_pb2.GetChannelRequest(channel_id=pairs[2].channel_ref_id)) self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) self.assertEqual(resp.channel.data.calls_succeeded, k_success) self.assertEqual(resp.channel.data.calls_failed, k_failed) # The fourth channel saw nothing resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[3]))) + channelz_pb2.GetChannelRequest(channel_id=pairs[3].channel_ref_id)) self.assertEqual(resp.channel.data.calls_started, 0) self.assertEqual(resp.channel.data.calls_succeeded, 0) self.assertEqual(resp.channel.data.calls_failed, 0) + await _destroy_channel_server_pairs(pairs) + async def test_many_subchannels(self): k_channels = 4 - idx = await _generate_channel_server_pairs(k_channels) + pairs = await _create_channel_server_pairs(k_channels, + self._channelz_stub) k_success = 17 k_failed = 19 for i in range(k_success): - await self._send_successful_unary_unary(idx[0]) - await self._send_successful_unary_unary(idx[2]) + await self._send_successful_unary_unary(pairs[0]) + await self._send_successful_unary_unary(pairs[2]) for i in range(k_failed): - await self._send_failed_unary_unary(idx[1]) - await self._send_failed_unary_unary(idx[2]) + await self._send_failed_unary_unary(pairs[1]) + await self._send_failed_unary_unary(pairs[2]) for i in range(k_channels): gc_resp = await self._channelz_stub.GetChannel( channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[i]))) + channel_id=pairs[i].channel_ref_id)) # If no call performed in the channel, there shouldn't be any subchannel if gc_resp.channel.data.calls_started == 0: self.assertEqual(len(gc_resp.channel.subchannel_ref), 0) @@ -285,36 +293,42 @@ class ChannelzServicerTest(AioTestBase): self.assertEqual(gc_resp.channel.data.calls_failed, gsc_resp.subchannel.data.calls_failed) + await _destroy_channel_server_pairs(pairs) + async def test_server_call(self): - idx = await _generate_channel_server_pairs(1) + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + k_success = 23 k_failed = 29 for i in range(k_success): - await self._send_successful_unary_unary(idx[0]) + await self._send_successful_unary_unary(pairs[0]) for i in range(k_failed): - await self._send_failed_unary_unary(idx[0]) + await self._send_failed_unary_unary(pairs[0]) - resp = await self._get_server_by_id(idx[0]) + resp = await self._get_server_by_ref_id(pairs[0].server_ref_id) self.assertEqual(resp.data.calls_started, k_success + k_failed) self.assertEqual(resp.data.calls_succeeded, k_success) self.assertEqual(resp.data.calls_failed, k_failed) + await _destroy_channel_server_pairs(pairs) + async def test_many_subchannels_and_sockets(self): k_channels = 4 - idx = await _generate_channel_server_pairs(k_channels) + pairs = await _create_channel_server_pairs(k_channels, + self._channelz_stub) k_success = 3 k_failed = 5 for i in range(k_success): - await self._send_successful_unary_unary(idx[0]) - await self._send_successful_unary_unary(idx[2]) + await self._send_successful_unary_unary(pairs[0]) + await self._send_successful_unary_unary(pairs[2]) for i in range(k_failed): - await self._send_failed_unary_unary(idx[1]) - await self._send_failed_unary_unary(idx[2]) + await self._send_failed_unary_unary(pairs[1]) + await self._send_failed_unary_unary(pairs[2]) for i in range(k_channels): gc_resp = await self._channelz_stub.GetChannel( channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[i]))) + channel_id=pairs[i].channel_ref_id)) # If no call performed in the channel, there shouldn't be any subchannel if gc_resp.channel.data.calls_started == 0: @@ -340,15 +354,16 @@ class ChannelzServicerTest(AioTestBase): self.assertEqual(gsc_resp.subchannel.data.calls_started, gs_resp.socket.data.messages_sent) + await _destroy_channel_server_pairs(pairs) + async def test_streaming_rpc(self): - idx = await _generate_channel_server_pairs(1) + pairs = await _create_channel_server_pairs(1, self._channelz_stub) # In C++, the argument for _send_successful_stream_stream is message length. # Here the argument is still channel idx, to be consistent with the other two. - await self._send_successful_stream_stream(idx[0]) + await self._send_successful_stream_stream(pairs[0]) gc_resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest( - channel_id=await self._get_channel_id(idx[0]))) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) self.assertEqual(gc_resp.channel.data.calls_started, 1) self.assertEqual(gc_resp.channel.data.calls_succeeded, 1) self.assertEqual(gc_resp.channel.data.calls_failed, 0) @@ -375,12 +390,15 @@ class ChannelzServicerTest(AioTestBase): self.assertEqual(gs_resp.socket.data.messages_received, test_constants.STREAM_LENGTH) + await _destroy_channel_server_pairs(pairs) + async def test_server_sockets(self): - idx = await _generate_channel_server_pairs(1) - await self._send_successful_unary_unary(idx[0]) - await self._send_failed_unary_unary(idx[0]) + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + + await self._send_successful_unary_unary(pairs[0]) + await self._send_failed_unary_unary(pairs[0]) - resp = await self._get_server_by_id(idx[0]) + resp = await self._get_server_by_ref_id(pairs[0].server_ref_id) self.assertEqual(resp.data.calls_started, 2) self.assertEqual(resp.data.calls_succeeded, 1) self.assertEqual(resp.data.calls_failed, 1) @@ -390,11 +408,12 @@ class ChannelzServicerTest(AioTestBase): start_socket_id=0)) # If the RPC call failed, it will raise a grpc.RpcError # So, if there is no exception raised, considered pass + await _destroy_channel_server_pairs(pairs) async def test_server_listen_sockets(self): - idx = await _generate_channel_server_pairs(1) + pairs = await _create_channel_server_pairs(1, self._channelz_stub) - resp = await self._get_server_by_id(idx[0]) + resp = await self._get_server_by_ref_id(pairs[0].server_ref_id) self.assertEqual(len(resp.listen_socket), 1) gs_resp = await self._channelz_stub.GetSocket( @@ -402,6 +421,7 @@ class ChannelzServicerTest(AioTestBase): socket_id=resp.listen_socket[0].socket_id)) # If the RPC call failed, it will raise a grpc.RpcError # So, if there is no exception raised, considered pass + await _destroy_channel_server_pairs(pairs) async def test_invalid_query_get_server(self): with self.assertRaises(aio.AioRpcError) as exception_context: