mirror of https://github.com/grpc/grpc.git
parent
10e39e316c
commit
5afd77398e
16 changed files with 893 additions and 195 deletions
@ -0,0 +1,55 @@ |
||||
# Copyright 2019 The gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from grpc._cython import cygrpc |
||||
|
||||
NoCompression = cygrpc.CompressionAlgorithm.none |
||||
Deflate = cygrpc.CompressionAlgorithm.deflate |
||||
Gzip = cygrpc.CompressionAlgorithm.gzip |
||||
|
||||
_METADATA_STRING_MAPPING = { |
||||
NoCompression: 'identity', |
||||
Deflate: 'deflate', |
||||
Gzip: 'gzip', |
||||
} |
||||
|
||||
|
||||
def _compression_algorithm_to_metadata_value(compression): |
||||
return _METADATA_STRING_MAPPING[compression] |
||||
|
||||
|
||||
def compression_algorithm_to_metadata(compression): |
||||
return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, |
||||
_compression_algorithm_to_metadata_value(compression)) |
||||
|
||||
|
||||
def create_channel_option(compression): |
||||
return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM, |
||||
int(compression)),) if compression else () |
||||
|
||||
|
||||
def augment_metadata(metadata, compression): |
||||
if not metadata and not compression: |
||||
return None |
||||
base_metadata = tuple(metadata) if metadata else () |
||||
compression_metadata = ( |
||||
compression_algorithm_to_metadata(compression),) if compression else () |
||||
return base_metadata + compression_metadata |
||||
|
||||
|
||||
__all__ = ( |
||||
"NoCompression", |
||||
"Deflate", |
||||
"Gzip", |
||||
) |
@ -0,0 +1,164 @@ |
||||
# Copyright 2019 the gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
""" Proxies a TCP connection between a single client-server pair. |
||||
|
||||
This proxy is not suitable for production, but should work well for cases in |
||||
which a test needs to spy on the bytes put on the wire between a server and |
||||
a client. |
||||
""" |
||||
|
||||
from __future__ import absolute_import |
||||
from __future__ import division |
||||
from __future__ import print_function |
||||
|
||||
import datetime |
||||
import select |
||||
import socket |
||||
import threading |
||||
|
||||
_TCP_PROXY_BUFFER_SIZE = 1024 |
||||
_TCP_PROXY_TIMEOUT = datetime.timedelta(milliseconds=500) |
||||
|
||||
|
||||
def _create_socket_ipv6(bind_address): |
||||
listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) |
||||
listen_socket.bind((bind_address, 0, 0, 0)) |
||||
return listen_socket |
||||
|
||||
|
||||
def _create_socket_ipv4(bind_address): |
||||
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
||||
listen_socket.bind((bind_address, 0)) |
||||
return listen_socket |
||||
|
||||
|
||||
def _init_listen_socket(bind_address): |
||||
listen_socket = None |
||||
if socket.has_ipv6: |
||||
try: |
||||
listen_socket = _create_socket_ipv6(bind_address) |
||||
except socket.error: |
||||
listen_socket = _create_socket_ipv4(bind_address) |
||||
else: |
||||
listen_socket = _create_socket_ipv4(bind_address) |
||||
listen_socket.listen(1) |
||||
return listen_socket, listen_socket.getsockname()[1] |
||||
|
||||
|
||||
def _init_proxy_socket(gateway_address, gateway_port): |
||||
proxy_socket = socket.create_connection((gateway_address, gateway_port)) |
||||
return proxy_socket |
||||
|
||||
|
||||
class TcpProxy(object): |
||||
"""Proxies a TCP connection between one client and one server.""" |
||||
|
||||
def __init__(self, bind_address, gateway_address, gateway_port): |
||||
self._bind_address = bind_address |
||||
self._gateway_address = gateway_address |
||||
self._gateway_port = gateway_port |
||||
|
||||
self._byte_count_lock = threading.RLock() |
||||
self._sent_byte_count = 0 |
||||
self._received_byte_count = 0 |
||||
|
||||
self._stop_event = threading.Event() |
||||
|
||||
self._port = None |
||||
self._listen_socket = None |
||||
self._proxy_socket = None |
||||
|
||||
# The following three attributes are owned by the serving thread. |
||||
self._northbound_data = b"" |
||||
self._southbound_data = b"" |
||||
self._client_sockets = [] |
||||
|
||||
self._thread = threading.Thread(target=self._run_proxy) |
||||
|
||||
def start(self): |
||||
self._listen_socket, self._port = _init_listen_socket( |
||||
self._bind_address) |
||||
self._proxy_socket = _init_proxy_socket(self._gateway_address, |
||||
self._gateway_port) |
||||
self._thread.start() |
||||
|
||||
def get_port(self): |
||||
return self._port |
||||
|
||||
def _handle_reads(self, sockets_to_read): |
||||
for socket_to_read in sockets_to_read: |
||||
if socket_to_read is self._listen_socket: |
||||
client_socket, client_address = socket_to_read.accept() |
||||
self._client_sockets.append(client_socket) |
||||
elif socket_to_read is self._proxy_socket: |
||||
data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) |
||||
with self._byte_count_lock: |
||||
self._received_byte_count += len(data) |
||||
self._northbound_data += data |
||||
elif socket_to_read in self._client_sockets: |
||||
data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) |
||||
if data: |
||||
with self._byte_count_lock: |
||||
self._sent_byte_count += len(data) |
||||
self._southbound_data += data |
||||
else: |
||||
self._client_sockets.remove(socket_to_read) |
||||
else: |
||||
raise RuntimeError('Unidentified socket appeared in read set.') |
||||
|
||||
def _handle_writes(self, sockets_to_write): |
||||
for socket_to_write in sockets_to_write: |
||||
if socket_to_write is self._proxy_socket: |
||||
if self._southbound_data: |
||||
self._proxy_socket.sendall(self._southbound_data) |
||||
self._southbound_data = b"" |
||||
elif socket_to_write in self._client_sockets: |
||||
if self._northbound_data: |
||||
socket_to_write.sendall(self._northbound_data) |
||||
self._northbound_data = b"" |
||||
|
||||
def _run_proxy(self): |
||||
while not self._stop_event.is_set(): |
||||
expected_reads = (self._listen_socket, self._proxy_socket) + tuple( |
||||
self._client_sockets) |
||||
expected_writes = expected_reads |
||||
sockets_to_read, sockets_to_write, _ = select.select( |
||||
expected_reads, expected_writes, (), |
||||
_TCP_PROXY_TIMEOUT.total_seconds()) |
||||
self._handle_reads(sockets_to_read) |
||||
self._handle_writes(sockets_to_write) |
||||
for client_socket in self._client_sockets: |
||||
client_socket.close() |
||||
|
||||
def stop(self): |
||||
self._stop_event.set() |
||||
self._thread.join() |
||||
self._listen_socket.close() |
||||
self._proxy_socket.close() |
||||
|
||||
def get_byte_count(self): |
||||
with self._byte_count_lock: |
||||
return self._sent_byte_count, self._received_byte_count |
||||
|
||||
def reset_byte_count(self): |
||||
with self._byte_count_lock: |
||||
self._byte_count = 0 |
||||
self._received_byte_count = 0 |
||||
|
||||
def __enter__(self): |
||||
self.start() |
||||
return self |
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb): |
||||
self.stop() |
Loading…
Reference in new issue