Merge branch 'master' into failhijackedrecv

pull/17179/head
Yash Tibrewal 6 years ago
commit 059459a9ee
  1. 23
      include/grpcpp/impl/codegen/call_op_set.h
  2. 11
      include/grpcpp/impl/codegen/interceptor.h
  3. 27
      include/grpcpp/impl/codegen/interceptor_common.h
  4. 4
      src/python/grpcio_tests/tests/unit/_cython/_fork_test.py
  5. 104
      src/python/grpcio_tests/tests/unit/_logging_test.py
  6. 176
      test/cpp/end2end/client_interceptors_end2end_test.cc
  7. 10
      test/cpp/end2end/interceptors_util.cc
  8. 3
      test/cpp/end2end/interceptors_util.h

@ -326,21 +326,37 @@ class CallOpSendMessage {
// Flags are per-message: clear them after use.
write_options_.Clear();
}
void FinishOp(bool* status) { send_buf_.Clear(); }
void FinishOp(bool* status) {
if (!send_buf_.Valid()) {
return;
}
if (hijacked_ && failed_send_) {
// Hijacking interceptor failed this Op
*status = false;
} else if (!*status) {
// This Op was passed down to core and the Op failed
failed_send_ = true;
}
}
void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
if (!send_buf_.Valid()) return;
interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE);
interceptor_methods->SetSendMessage(&send_buf_, msg_);
interceptor_methods->SetSendMessage(&send_buf_, msg_, &failed_send_);
}
void SetFinishInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
if (send_buf_.Valid()) {
interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_SEND_MESSAGE);
}
send_buf_.Clear();
// The contents of the SendMessage value that was previously set
// has had its references stolen by core's operations
interceptor_methods->SetSendMessage(nullptr, nullptr);
interceptor_methods->SetSendMessage(nullptr, nullptr, &failed_send_);
}
void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
@ -350,6 +366,7 @@ class CallOpSendMessage {
private:
const void* msg_ = nullptr; // The original non-serialized message
bool hijacked_ = false;
bool failed_send_ = false;
ByteBuffer send_buf_;
WriteOptions write_options_;
};

@ -46,9 +46,10 @@ namespace experimental {
/// operation has been requested and it is available. POST_RECV means that a
/// result is available but has not yet been passed back to the application.
enum class InterceptionHookPoints {
/// The first two in this list are for clients and servers
/// The first three in this list are for clients and servers
PRE_SEND_INITIAL_METADATA,
PRE_SEND_MESSAGE,
POST_SEND_MESSAGE,
PRE_SEND_STATUS, // server only
PRE_SEND_CLOSE, // client only: WritesDone for stream; after write in unary
/// The following three are for hijacked clients only and can only be
@ -117,6 +118,10 @@ class InterceptorBatchMethods {
/// only supported for sync and callback APIs at the present moment.
virtual const void* GetSendMessage() = 0;
/// Checks whether the SEND MESSAGE op succeeded. Valid for POST_SEND_MESSAGE
/// interceptions.
virtual bool GetSendMessageStatus() = 0;
/// Returns a modifiable multimap of the initial metadata to be sent. Valid
/// for PRE_SEND_INITIAL_METADATA interceptions. A value of nullptr indicates
/// that this field is not valid.
@ -167,6 +172,10 @@ class InterceptorBatchMethods {
/// op. This would be a signal to the reader that there will be no more
/// messages, or the stream has failed or been cancelled.
virtual void FailHijackedRecvMessage() = 0;
/// On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND
/// MESSAGE op
virtual void FailHijackedSendMessage() = 0;
};
/// Interface for an interceptor. Interceptor authors must create a class

@ -83,6 +83,8 @@ class InterceptorBatchMethodsImpl
const void* GetSendMessage() override { return orig_send_message_; }
bool GetSendMessageStatus() override { return !*fail_send_message_; }
std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override {
return send_initial_metadata_;
}
@ -112,14 +114,22 @@ class InterceptorBatchMethodsImpl
Status* GetRecvStatus() override { return recv_status_; }
void FailHijackedSendMessage() override {
GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
*fail_send_message_ = true;
}
std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
override {
return recv_trailing_metadata_->map();
}
void SetSendMessage(ByteBuffer* buf, const void* msg) {
void SetSendMessage(ByteBuffer* buf, const void* msg,
bool* fail_send_message) {
send_message_ = buf;
orig_send_message_ = msg;
fail_send_message_ = fail_send_message;
}
void SetSendInitialMetadata(
@ -348,6 +358,7 @@ class InterceptorBatchMethodsImpl
std::function<void(void)> callback_;
ByteBuffer* send_message_ = nullptr;
bool* fail_send_message_ = nullptr;
const void* orig_send_message_ = nullptr;
std::multimap<grpc::string, grpc::string>* send_initial_metadata_;
@ -402,6 +413,14 @@ class CancelInterceptorBatchMethods
return nullptr;
}
bool GetSendMessageStatus() override {
GPR_CODEGEN_ASSERT(
false &&
"It is illegal to call GetSendMessageStatus on a method which "
"has a Cancel notification");
return false;
}
const void* GetSendMessage() override {
GPR_CODEGEN_ASSERT(
false &&
@ -481,6 +500,12 @@ class CancelInterceptorBatchMethods
"It is illegal to call FailHijackedRecvMessage on a "
"method which has a Cancel notification");
}
void FailHijackedSendMessage() override {
GPR_CODEGEN_ASSERT(false &&
"It is illegal to call FailHijackedSendMessage on a "
"method which has a Cancel notification");
}
};
} // namespace internal
} // namespace grpc

@ -27,6 +27,7 @@ def _get_number_active_threads():
class ForkPosixTester(unittest.TestCase):
def setUp(self):
self._saved_fork_support_flag = cygrpc._GRPC_ENABLE_FORK_SUPPORT
cygrpc._GRPC_ENABLE_FORK_SUPPORT = True
def testForkManagedThread(self):
@ -50,6 +51,9 @@ class ForkPosixTester(unittest.TestCase):
thread.join()
self.assertEqual(0, _get_number_active_threads())
def tearDown(self):
cygrpc._GRPC_ENABLE_FORK_SUPPORT = self._saved_fork_support_flag
@unittest.skipUnless(os.name == 'nt', 'Windows-specific tests')
class ForkWindowsTester(unittest.TestCase):

@ -14,66 +14,86 @@
"""Test of gRPC Python's interaction with the python logging module"""
import unittest
import six
from six.moves import reload_module
import logging
import grpc
import functools
import subprocess
import sys
INTERPRETER = sys.executable
def patch_stderr(f):
@functools.wraps(f)
def _impl(*args, **kwargs):
old_stderr = sys.stderr
sys.stderr = six.StringIO()
try:
f(*args, **kwargs)
finally:
sys.stderr = old_stderr
class LoggingTest(unittest.TestCase):
return _impl
def test_logger_not_occupied(self):
script = """if True:
import logging
import grpc
def isolated_logging(f):
if len(logging.getLogger().handlers) != 0:
raise Exception('expected 0 logging handlers')
@functools.wraps(f)
def _impl(*args, **kwargs):
reload_module(logging)
reload_module(grpc)
try:
f(*args, **kwargs)
finally:
reload_module(logging)
"""
self._verifyScriptSucceeds(script)
return _impl
def test_handler_found(self):
script = """if True:
import logging
import grpc
"""
out, err = self._verifyScriptSucceeds(script)
self.assertEqual(0, len(err), 'unexpected output to stderr')
class LoggingTest(unittest.TestCase):
def test_can_configure_logger(self):
script = """if True:
import logging
import six
@isolated_logging
def test_logger_not_occupied(self):
self.assertEqual(0, len(logging.getLogger().handlers))
import grpc
@patch_stderr
@isolated_logging
def test_handler_found(self):
self.assertEqual(0, len(sys.stderr.getvalue()))
@isolated_logging
def test_can_configure_logger(self):
intended_stream = six.StringIO()
logging.basicConfig(stream=intended_stream)
self.assertEqual(1, len(logging.getLogger().handlers))
self.assertIs(logging.getLogger().handlers[0].stream, intended_stream)
intended_stream = six.StringIO()
logging.basicConfig(stream=intended_stream)
if len(logging.getLogger().handlers) != 1:
raise Exception('expected 1 logging handler')
if logging.getLogger().handlers[0].stream is not intended_stream:
raise Exception('wrong handler stream')
"""
self._verifyScriptSucceeds(script)
@isolated_logging
def test_grpc_logger(self):
self.assertIn("grpc", logging.Logger.manager.loggerDict)
root_logger = logging.getLogger("grpc")
self.assertEqual(1, len(root_logger.handlers))
self.assertIsInstance(root_logger.handlers[0], logging.NullHandler)
script = """if True:
import logging
import grpc
if "grpc" not in logging.Logger.manager.loggerDict:
raise Exception('grpc logger not found')
root_logger = logging.getLogger("grpc")
if len(root_logger.handlers) != 1:
raise Exception('expected 1 root logger handler')
if not isinstance(root_logger.handlers[0], logging.NullHandler):
raise Exception('expected logging.NullHandler')
"""
self._verifyScriptSucceeds(script)
def _verifyScriptSucceeds(self, script):
process = subprocess.Popen(
[INTERPRETER, '-c', script],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
out, err = process.communicate()
self.assertEqual(
0, process.returncode,
'process failed with exit code %d (stdout: %s, stderr: %s)' %
(process.returncode, out, err))
return out, err
if __name__ == '__main__':

@ -270,6 +270,129 @@ class HijackingInterceptorMakesAnotherCallFactory
}
};
class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
public:
BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
info_ = info;
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
bool hijack = false;
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
hijack = true;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
.ok());
EXPECT_EQ(req.message().find("Hello"), 0u);
msg = req.message();
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
// Got nothing to do here for now
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
"testvalue");
auto* status = methods->GetRecvStatus();
EXPECT_EQ(status->ok(), true);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
EchoResponse* resp =
static_cast<EchoResponse*>(methods->GetRecvMessage());
resp->set_message(msg);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
->message()
.find("Hello"),
0u);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
auto* map = methods->GetRecvTrailingMetadata();
// insert the metadata that we want
EXPECT_EQ(map->size(), static_cast<unsigned>(0));
map->insert(std::make_pair("testkey", "testvalue"));
auto* status = methods->GetRecvStatus();
*status = Status(StatusCode::OK, "");
}
if (hijack) {
methods->Hijack();
} else {
methods->Proceed();
}
}
private:
experimental::ClientRpcInfo* info_;
grpc::string msg;
};
class ClientStreamingRpcHijackingInterceptor
: public experimental::Interceptor {
public:
ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
info_ = info;
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
bool hijack = false;
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
hijack = true;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
if (++count_ > 10) {
methods->FailHijackedSendMessage();
}
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
EXPECT_FALSE(got_failed_send_);
got_failed_send_ = !methods->GetSendMessageStatus();
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
auto* status = methods->GetRecvStatus();
*status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
}
if (hijack) {
methods->Hijack();
} else {
methods->Proceed();
}
}
static bool GotFailedSend() { return got_failed_send_; }
private:
experimental::ClientRpcInfo* info_;
int count_ = 0;
static bool got_failed_send_;
};
bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
class ClientStreamingRpcHijackingInterceptorFactory
: public experimental::ClientInterceptorFactoryInterface {
public:
virtual experimental::Interceptor* CreateClientInterceptor(
experimental::ClientRpcInfo* info) override {
return new ClientStreamingRpcHijackingInterceptor(info);
}
};
class ServerStreamingRpcHijackingInterceptor
: public experimental::Interceptor {
public:
@ -292,7 +415,7 @@ class ServerStreamingRpcHijackingInterceptor
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
auto* buffer = methods->GetSendMessage();
auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
@ -367,6 +490,15 @@ class ServerStreamingRpcHijackingInterceptorFactory
}
};
class BidiStreamingRpcHijackingInterceptorFactory
: public experimental::ClientInterceptorFactoryInterface {
public:
virtual experimental::Interceptor* CreateClientInterceptor(
experimental::ClientRpcInfo* info) override {
return new BidiStreamingRpcHijackingInterceptor(info);
}
};
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
@ -647,6 +779,35 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
ChannelArguments args;
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
creators;
creators.push_back(
std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
new ClientStreamingRpcHijackingInterceptorFactory()));
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
auto stub = grpc::testing::EchoTestService::NewStub(channel);
ClientContext ctx;
EchoRequest req;
EchoResponse resp;
req.mutable_param()->set_echo_metadata(true);
req.set_message("Hello");
string expected_resp = "";
auto writer = stub->RequestStream(&ctx, &resp);
for (int i = 0; i < 10; i++) {
EXPECT_TRUE(writer->Write(req));
expected_resp += "Hello";
}
// The interceptor will reject the 11th message
writer->Write(req);
Status s = writer->Finish();
EXPECT_EQ(s.ok(), false);
EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
@ -661,6 +822,19 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
creators;
creators.push_back(
std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
new BidiStreamingRpcHijackingInterceptorFactory()));
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeBidiStreamingCall(channel);
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
ChannelArguments args;
DummyInterceptor::Reset();

@ -132,6 +132,16 @@ bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
return false;
}
bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
const string& key, const string& value) {
for (const auto& pair : map) {
if (pair.first == key && pair.second == value) {
return true;
}
}
return false;
}
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors() {
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>

@ -165,6 +165,9 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
const string& key, const string& value);
bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
const string& key, const string& value);
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors();

Loading…
Cancel
Save