Rewrite ProtoFileParser

pull/7735/head
Yuchen Zeng 9 years ago
parent 68ca351126
commit 4272cac7ae
  1. 144
      test/cpp/util/grpc_tool.cc
  2. 187
      test/cpp/util/proto_file_parser.cc
  3. 44
      test/cpp/util/proto_file_parser.h

@ -31,7 +31,7 @@
* *
*/ */
#include "test/cpp/util/grpc_tool.h" #include "grpc_tool.h"
#include <unistd.h> #include <unistd.h>
#include <fstream> #include <fstream>
@ -55,15 +55,14 @@
DEFINE_bool(enable_ssl, false, "Whether to use ssl/tls."); DEFINE_bool(enable_ssl, false, "Whether to use ssl/tls.");
DEFINE_bool(use_auth, false, "Whether to create default google credentials."); DEFINE_bool(use_auth, false, "Whether to create default google credentials.");
DEFINE_string(input_binary_file, "", DEFINE_bool(remotedb, true, "Use server types to parse and format messages");
"Path to input file containing serialized request.");
DEFINE_string(output_binary_file, "",
"Path to output file to write serialized response.");
DEFINE_string(metadata, "", DEFINE_string(metadata, "",
"Metadata to send to server, in the form of key1:val1:key2:val2"); "Metadata to send to server, in the form of key1:val1:key2:val2");
DEFINE_string(proto_path, ".", "Path to look for the proto file."); DEFINE_string(proto_path, ".", "Path to look for the proto file.");
// TODO(zyc): support a list of input proto files DEFINE_string(proto_file, "", "Name of the proto file.");
DEFINE_string(protofiles, "", "Name of the proto file."); DEFINE_bool(binary_input, false, "Input in binary format");
DEFINE_bool(binary_output, false, "Output in binary format");
DEFINE_string(infile, "", "Input file (default is stdin)");
namespace grpc { namespace grpc {
namespace testing { namespace testing {
@ -73,8 +72,22 @@ class GrpcTool {
public: public:
explicit GrpcTool(); explicit GrpcTool();
virtual ~GrpcTool() {} virtual ~GrpcTool() {}
bool Help(int argc, const char** argv, GrpcToolOutputCallback callback); bool Help(int argc, const char** argv, GrpcToolOutputCallback callback);
bool CallMethod(int argc, const char** argv, GrpcToolOutputCallback callback); bool CallMethod(int argc, const char** argv, GrpcToolOutputCallback callback);
// TODO(zyc): implement the following methods
// bool ListServices(int argc, const char** argv, GrpcToolOutputCallback
// callback);
// bool PrintType(int argc, const char** argv, GrpcToolOutputCallback
// callback);
// bool PrintTypeId(int argc, const char** argv, GrpcToolOutputCallback
// callback);
// bool ParseMessage(int argc, const char** argv, GrpcToolOutputCallback
// callback);
// bool ToText(int argc, const char** argv, GrpcToolOutputCallback callback);
// bool ToBinary(int argc, const char** argv, GrpcToolOutputCallback
// callback);
void SetPrintCommandMode(int exit_status) { void SetPrintCommandMode(int exit_status) {
print_command_usage_ = true; print_command_usage_ = true;
usage_exit_status_ = exit_status; usage_exit_status_ = exit_status;
@ -82,6 +95,7 @@ class GrpcTool {
private: private:
void CommandUsage(const grpc::string& usage) const; void CommandUsage(const grpc::string& usage) const;
std::shared_ptr<grpc::Channel> NewChannel(const grpc::string& server_address);
bool print_command_usage_; bool print_command_usage_;
int usage_exit_status_; int usage_exit_status_;
}; };
@ -222,6 +236,21 @@ void GrpcTool::CommandUsage(const grpc::string& usage) const {
} }
} }
std::shared_ptr<grpc::Channel> GrpcTool::NewChannel(
const grpc::string& server_address) {
std::shared_ptr<grpc::ChannelCredentials> creds;
if (!FLAGS_enable_ssl) {
creds = grpc::InsecureChannelCredentials();
} else {
if (FLAGS_use_auth) {
creds = grpc::GoogleDefaultCredentials();
} else {
creds = grpc::SslCredentials(grpc::SslCredentialsOptions());
}
}
return grpc::CreateChannel(server_address, creds);
}
bool GrpcTool::Help(int argc, const char** argv, bool GrpcTool::Help(int argc, const char** argv,
GrpcToolOutputCallback callback) { GrpcToolOutputCallback callback) {
CommandUsage( CommandUsage(
@ -250,17 +279,18 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
" <service> ; Exported service name\n" " <service> ; Exported service name\n"
" <method> ; Method name\n" " <method> ; Method name\n"
" <request> ; Text protobuffer (overrides infile)\n" " <request> ; Text protobuffer (overrides infile)\n"
" --protofiles ; Comma separated proto files used as a" " --proto_file ; Comma separated proto files used as a"
" fallback when parsing request/response\n" " fallback when parsing request/response\n"
" --proto_path ; The search path of proto files, valid" " --proto_path ; The search path of proto files, valid"
" only when --protofiles is given\n" " only when --proto_file is given\n"
" --metadata ; The metadata to be sent to the server\n" " --metadata ; The metadata to be sent to the server\n"
" --enable_ssl ; Set whether to use tls\n" " --enable_ssl ; Set whether to use tls\n"
" --use_auth ; Set whether to create default google" " --use_auth ; Set whether to create default google"
" credentials\n" " credentials\n"
" --infile ; Input filename (defaults to stdin)\n"
" --outfile ; Output filename (defaults to stdout)\n" " --outfile ; Output filename (defaults to stdout)\n"
" --input_binary_file ; Path to input file in binary format\n" " --binary_input ; Input in binary format\n"
" --binary_output ; Path to output file in binary format\n"); " --binary_output ; Output in binary format\n");
std::stringstream output_ss; std::stringstream output_ss;
grpc::string request_text; grpc::string request_text;
@ -271,63 +301,44 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
if (argc == 3) { if (argc == 3) {
request_text = argv[2]; request_text = argv[2];
} if (!FLAGS_infile.empty()) {
fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
std::shared_ptr<grpc::ChannelCredentials> creds; }
if (!FLAGS_enable_ssl) {
creds = grpc::InsecureChannelCredentials();
} else { } else {
if (FLAGS_use_auth) { std::stringstream input_stream;
creds = grpc::GoogleDefaultCredentials(); if (FLAGS_infile.empty()) {
if (isatty(STDIN_FILENO)) {
fprintf(stderr, "reading request message from stdin...\n");
}
input_stream << std::cin.rdbuf();
} else { } else {
creds = grpc::SslCredentials(grpc::SslCredentialsOptions()); std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary);
} input_stream << input_file.rdbuf();
} input_file.close();
std::shared_ptr<grpc::Channel> channel =
grpc::CreateChannel(server_address, creds);
if (request_text.empty() && FLAGS_input_binary_file.empty()) {
if (isatty(STDIN_FILENO)) {
std::cout << "reading request message from stdin..." << std::endl;
} }
std::stringstream input_stream;
input_stream << std::cin.rdbuf();
request_text = input_stream.str(); request_text = input_stream.str();
} }
if (!request_text.empty()) { std::shared_ptr<grpc::Channel> channel = NewChannel(server_address);
if (!FLAGS_protofiles.empty()) { if (!FLAGS_binary_input || !FLAGS_binary_output) {
parser.reset(new grpc::testing::ProtoFileParser( parser.reset(
FLAGS_proto_path, FLAGS_protofiles, method_name)); new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr,
} else { FLAGS_proto_path, FLAGS_proto_file));
parser.reset(new grpc::testing::ProtoFileParser(channel, method_name));
}
method_name = parser->GetFullMethodName();
if (parser->HasError()) { if (parser->HasError()) {
return 1; return false;
}
if (!FLAGS_input_binary_file.empty()) {
std::cout
<< "warning: request given in argv, ignoring --input_binary_file"
<< std::endl;
} }
} }
if (parser) { if (FLAGS_binary_input) {
serialized_request_proto = serialized_request_proto = request_text;
parser->GetSerializedProto(request_text, true /* is_request */); } else {
serialized_request_proto = parser->GetSerializedProtoFromMethod(
method_name, request_text, true /* is_request */);
if (parser->HasError()) { if (parser->HasError()) {
return 1; return false;
} }
} else if (!FLAGS_input_binary_file.empty()) {
std::ifstream input_file(FLAGS_input_binary_file,
std::ios::in | std::ios::binary);
std::stringstream input_stream;
input_stream << input_file.rdbuf();
serialized_request_proto = input_stream.str();
} }
std::cout << "connecting to " << server_address << std::endl; std::cerr << "connecting to " << server_address << std::endl;
grpc::string serialized_response_proto; grpc::string serialized_response_proto;
std::multimap<grpc::string, grpc::string> client_metadata; std::multimap<grpc::string, grpc::string> client_metadata;
@ -336,30 +347,27 @@ bool GrpcTool::CallMethod(int argc, const char** argv,
ParseMetadataFlag(&client_metadata); ParseMetadataFlag(&client_metadata);
PrintMetadata(client_metadata, "Sending client initial metadata:"); PrintMetadata(client_metadata, "Sending client initial metadata:");
grpc::Status s = grpc::testing::CliCall::Call( grpc::Status s = grpc::testing::CliCall::Call(
channel, method_name, serialized_request_proto, channel, parser->GetFormatedMethodName(method_name),
&serialized_response_proto, client_metadata, &server_initial_metadata, serialized_request_proto, &serialized_response_proto, client_metadata,
&server_trailing_metadata); &server_initial_metadata, &server_trailing_metadata);
PrintMetadata(server_initial_metadata, PrintMetadata(server_initial_metadata,
"Received initial metadata from server:"); "Received initial metadata from server:");
PrintMetadata(server_trailing_metadata, PrintMetadata(server_trailing_metadata,
"Received trailing metadata from server:"); "Received trailing metadata from server:");
if (s.ok()) { if (s.ok()) {
std::cout << "Rpc succeeded with OK status" << std::endl; std::cerr << "Rpc succeeded with OK status" << std::endl;
if (parser) { if (FLAGS_binary_output) {
grpc::string response_text = parser->GetTextFormat( output_ss << serialized_response_proto;
serialized_response_proto, false /* is_request */); } else {
grpc::string response_text = parser->GetTextFormatFromMethod(
method_name, serialized_response_proto, false /* is_request */);
if (parser->HasError()) { if (parser->HasError()) {
return false; return false;
} }
output_ss << "Response: \n " << response_text << std::endl; output_ss << "Response: \n " << response_text << std::endl;
} }
if (!FLAGS_output_binary_file.empty()) {
std::ofstream output_file(FLAGS_output_binary_file,
std::ios::trunc | std::ios::binary);
output_file << serialized_response_proto;
}
} else { } else {
std::cout << "Rpc failed with status code " << s.error_code() std::cerr << "Rpc failed with status code " << s.error_code()
<< ", error message: " << s.error_message() << std::endl; << ", error message: " << s.error_message() << std::endl;
} }

@ -71,7 +71,7 @@ class ErrorPrinter
void AddWarning(const grpc::string& filename, int line, int column, void AddWarning(const grpc::string& filename, int line, int column,
const grpc::string& message) GRPC_OVERRIDE { const grpc::string& message) GRPC_OVERRIDE {
std::cout << "warning " << filename << " " << line << " " << column << " " std::cerr << "warning " << filename << " " << line << " " << column << " "
<< message << std::endl; << message << std::endl;
} }
@ -79,62 +79,72 @@ class ErrorPrinter
ProtoFileParser* parser_; // not owned ProtoFileParser* parser_; // not owned
}; };
ProtoFileParser::ProtoFileParser(const grpc::string& proto_path, ProtoFileParser::ProtoFileParser(std::shared_ptr<grpc::Channel> channel,
const grpc::string& file_name, const grpc::string& proto_path,
const grpc::string& method) const grpc::string& protofiles)
: has_error_(false) { : has_error_(false) {
source_tree_.MapPath("", proto_path); std::vector<std::string> service_list;
error_printer_.reset(new ErrorPrinter(this)); if (channel) {
importer_.reset(new google::protobuf::compiler::Importer( reflection_db_.reset(new grpc::ProtoReflectionDescriptorDatabase(channel));
&source_tree_, error_printer_.get())); reflection_db_->GetServices(&service_list);
const auto* file_desc = importer_->Import(file_name);
if (!file_desc) {
LogError("");
return;
} }
dynamic_factory_.reset(
new google::protobuf::DynamicMessageFactory(importer_->pool()));
std::vector<const google::protobuf::ServiceDescriptor*> service_desc_list; if (!protofiles.empty()) {
for (int i = 0; i < file_desc->service_count(); i++) { source_tree_.MapPath("", proto_path);
service_desc_list.push_back(file_desc->service(i)); error_printer_.reset(new ErrorPrinter(this));
} importer_.reset(new google::protobuf::compiler::Importer(
InitProtoFileParser(method, service_desc_list); &source_tree_, error_printer_.get()));
}
ProtoFileParser::ProtoFileParser(std::shared_ptr<grpc::Channel> channel, grpc::string file_name;
const grpc::string& method) std::stringstream ss(protofiles);
: has_error_(false), while (std::getline(ss, file_name, ',')) {
desc_db_(new grpc::ProtoReflectionDescriptorDatabase(channel)), std::cerr << file_name << std::endl;
desc_pool_(new google::protobuf::DescriptorPool(desc_db_.get())) { const auto* file_desc = importer_->Import(file_name);
std::vector<std::string> service_list; if (file_desc) {
if (!desc_db_->GetServices(&service_list)) { for (int i = 0; i < file_desc->service_count(); i++) {
LogError( service_desc_list_.push_back(file_desc->service(i));
"Failed to get services from the server, " }
"it may not have the reflection service.\n" } else {
"Please try to use the --protofiles option to provide a proto file."); std::cerr << file_name << " not found" << std::endl;
}
}
file_db_.reset(
new google::protobuf::DescriptorPoolDatabase(*importer_->pool()));
} }
if (has_error_) {
if (!reflection_db_ && !file_db_) {
LogError("No available proto database");
return; return;
} }
if (!reflection_db_) {
desc_db_ = std::move(file_db_);
} else if (!file_db_) {
desc_db_ = std::move(reflection_db_);
} else {
desc_db_.reset(new google::protobuf::MergedDescriptorDatabase(
reflection_db_.get(), file_db_.get()));
}
desc_pool_.reset(new google::protobuf::DescriptorPool(desc_db_.get()));
dynamic_factory_.reset( dynamic_factory_.reset(
new google::protobuf::DynamicMessageFactory(desc_pool_.get())); new google::protobuf::DynamicMessageFactory(desc_pool_.get()));
std::vector<const google::protobuf::ServiceDescriptor*> service_desc_list;
for (auto it = service_list.begin(); it != service_list.end(); it++) { for (auto it = service_list.begin(); it != service_list.end(); it++) {
service_desc_list.push_back(desc_pool_->FindServiceByName(*it)); if (const google::protobuf::ServiceDescriptor* service_desc =
desc_pool_->FindServiceByName(*it)) {
service_desc_list_.push_back(service_desc);
}
} }
InitProtoFileParser(method, service_desc_list);
} }
ProtoFileParser::~ProtoFileParser() {} ProtoFileParser::~ProtoFileParser() {}
void ProtoFileParser::InitProtoFileParser( grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) {
const grpc::string& method, has_error_ = false;
const std::vector<const google::protobuf::ServiceDescriptor*>
service_desc_list) {
const google::protobuf::MethodDescriptor* method_descriptor = nullptr; const google::protobuf::MethodDescriptor* method_descriptor = nullptr;
for (auto it = service_desc_list.begin(); it != service_desc_list.end(); for (auto it = service_desc_list_.begin(); it != service_desc_list_.end();
it++) { it++) {
const auto* service_desc = *it; const auto* service_desc = *it;
for (int j = 0; j < service_desc->method_count(); j++) { for (int j = 0; j < service_desc->method_count(); j++) {
@ -154,26 +164,80 @@ void ProtoFileParser::InitProtoFileParser(
LogError("Method name not found"); LogError("Method name not found");
} }
if (has_error_) { if (has_error_) {
return; return "";
} }
full_method_name_ = method_descriptor->full_name();
size_t last_dot = full_method_name_.find_last_of('.'); return method_descriptor->full_name();
}
grpc::string ProtoFileParser::GetFormatedMethodName(
const grpc::string& method) {
has_error_ = false;
grpc::string formated_method_name = GetFullMethodName(method);
if (has_error_) {
return "";
}
size_t last_dot = formated_method_name.find_last_of('.');
if (last_dot != grpc::string::npos) { if (last_dot != grpc::string::npos) {
full_method_name_[last_dot] = '/'; formated_method_name[last_dot] = '/';
}
formated_method_name.insert(formated_method_name.begin(), '/');
return formated_method_name;
}
grpc::string ProtoFileParser::GetMessageTypeFromMethod(
const grpc::string& method, bool is_request) {
has_error_ = false;
grpc::string full_method_name = GetFullMethodName(method);
if (has_error_) {
return "";
}
const google::protobuf::MethodDescriptor* method_desc =
desc_pool_->FindMethodByName(full_method_name);
if (!method_desc) {
LogError("Method not found");
return "";
} }
full_method_name_.insert(full_method_name_.begin(), '/');
request_prototype_.reset( return is_request ? method_desc->input_type()->full_name()
dynamic_factory_->GetPrototype(method_descriptor->input_type())->New()); : method_desc->output_type()->full_name();
response_prototype_.reset(
dynamic_factory_->GetPrototype(method_descriptor->output_type())->New());
} }
grpc::string ProtoFileParser::GetSerializedProto( grpc::string ProtoFileParser::GetSerializedProtoFromMethod(
const grpc::string& text_format_proto, bool is_request) { const grpc::string& method, const grpc::string& text_format_proto,
bool is_request) {
has_error_ = false;
grpc::string message_type_name = GetMessageTypeFromMethod(method, is_request);
if (has_error_) {
return "";
}
return GetSerializedProtoFromMessageType(message_type_name,
text_format_proto);
}
grpc::string ProtoFileParser::GetTextFormatFromMethod(
const grpc::string& method, const grpc::string& serialized_proto,
bool is_request) {
has_error_ = false;
grpc::string message_type_name = GetMessageTypeFromMethod(method, is_request);
if (has_error_) {
return "";
}
return GetTextFormatFromMessageType(message_type_name, serialized_proto);
}
grpc::string ProtoFileParser::GetSerializedProtoFromMessageType(
const grpc::string& message_type_name,
const grpc::string& text_format_proto) {
has_error_ = false;
grpc::string serialized; grpc::string serialized;
grpc::protobuf::Message* msg = const google::protobuf::Descriptor* desc =
is_request ? request_prototype_.get() : response_prototype_.get(); desc_pool_->FindMessageTypeByName(message_type_name);
if (!desc) {
LogError("Message type not found");
return "";
}
grpc::protobuf::Message* msg = dynamic_factory_->GetPrototype(desc)->New();
bool ok = bool ok =
google::protobuf::TextFormat::ParseFromString(text_format_proto, msg); google::protobuf::TextFormat::ParseFromString(text_format_proto, msg);
if (!ok) { if (!ok) {
@ -188,10 +252,17 @@ grpc::string ProtoFileParser::GetSerializedProto(
return serialized; return serialized;
} }
grpc::string ProtoFileParser::GetTextFormat( grpc::string ProtoFileParser::GetTextFormatFromMessageType(
const grpc::string& serialized_proto, bool is_request) { const grpc::string& message_type_name,
grpc::protobuf::Message* msg = const grpc::string& serialized_proto) {
is_request ? request_prototype_.get() : response_prototype_.get(); has_error_ = false;
const google::protobuf::Descriptor* desc =
desc_pool_->FindMessageTypeByName(message_type_name);
if (!desc) {
LogError("Message type not found");
return "";
}
grpc::protobuf::Message* msg = dynamic_factory_->GetPrototype(desc)->New();
if (!msg->ParseFromString(serialized_proto)) { if (!msg->ParseFromString(serialized_proto)) {
LogError("Failed to deserialize proto."); LogError("Failed to deserialize proto.");
return ""; return "";
@ -206,7 +277,7 @@ grpc::string ProtoFileParser::GetTextFormat(
void ProtoFileParser::LogError(const grpc::string& error_msg) { void ProtoFileParser::LogError(const grpc::string& error_msg) {
if (!error_msg.empty()) { if (!error_msg.empty()) {
std::cout << error_msg << std::endl; std::cerr << error_msg << std::endl;
} }
has_error_ = true; has_error_ = true;
} }

@ -53,41 +53,57 @@ class ProtoFileParser {
// The given proto file_name will be searched in a source tree rooted from // The given proto file_name will be searched in a source tree rooted from
// proto_path. The method could be a partial string such as Service.Method or // proto_path. The method could be a partial string such as Service.Method or
// even just Method. It will log an error if there is ambiguity. // even just Method. It will log an error if there is ambiguity.
ProtoFileParser(const grpc::string& proto_path, const grpc::string& file_name,
const grpc::string& method);
ProtoFileParser(std::shared_ptr<grpc::Channel> channel, ProtoFileParser(std::shared_ptr<grpc::Channel> channel,
const grpc::string& method); const grpc::string& proto_path,
const grpc::string& protofiles);
~ProtoFileParser(); ~ProtoFileParser();
grpc::string GetFullMethodName() const { return full_method_name_; } // Full method name is in the form of Service.Method, it's good to be used in
// descriptor database queries.
grpc::string GetFullMethodName(const grpc::string& method);
// Formated method name is in the form of /Service/Method, it's good to be
// used as the argument of Stub::Call()
grpc::string GetFormatedMethodName(const grpc::string& method);
grpc::string GetSerializedProtoFromMethod(
const grpc::string& method, const grpc::string& text_format_proto,
bool is_request);
grpc::string GetTextFormatFromMethod(const grpc::string& method,
const grpc::string& serialized_proto,
bool is_request);
grpc::string GetSerializedProto(const grpc::string& text_format_proto, grpc::string GetSerializedProtoFromMessageType(
bool is_request); const grpc::string& message_type_name,
const grpc::string& text_format_proto);
grpc::string GetTextFormat(const grpc::string& serialized_proto, grpc::string GetTextFormatFromMessageType(
bool is_request); const grpc::string& message_type_name,
const grpc::string& serialized_proto);
bool HasError() const { return has_error_; } bool HasError() const { return has_error_; }
void LogError(const grpc::string& error_msg); void LogError(const grpc::string& error_msg);
private: private:
void InitProtoFileParser( grpc::string GetMessageTypeFromMethod(const grpc::string& method,
const grpc::string& method, bool is_request);
const std::vector<const google::protobuf::ServiceDescriptor*> services);
bool has_error_; bool has_error_;
grpc::string request_text_; grpc::string request_text_;
grpc::string full_method_name_;
google::protobuf::compiler::DiskSourceTree source_tree_; google::protobuf::compiler::DiskSourceTree source_tree_;
std::unique_ptr<ErrorPrinter> error_printer_; std::unique_ptr<ErrorPrinter> error_printer_;
std::unique_ptr<google::protobuf::compiler::Importer> importer_; std::unique_ptr<google::protobuf::compiler::Importer> importer_;
std::unique_ptr<grpc::ProtoReflectionDescriptorDatabase> desc_db_; std::unique_ptr<grpc::ProtoReflectionDescriptorDatabase> reflection_db_;
std::unique_ptr<google::protobuf::DescriptorPoolDatabase> file_db_;
std::unique_ptr<google::protobuf::DescriptorDatabase> desc_db_;
std::unique_ptr<google::protobuf::DescriptorPool> desc_pool_; std::unique_ptr<google::protobuf::DescriptorPool> desc_pool_;
std::unique_ptr<google::protobuf::DynamicMessageFactory> dynamic_factory_; std::unique_ptr<google::protobuf::DynamicMessageFactory> dynamic_factory_;
std::unique_ptr<grpc::protobuf::Message> request_prototype_; std::unique_ptr<grpc::protobuf::Message> request_prototype_;
std::unique_ptr<grpc::protobuf::Message> response_prototype_; std::unique_ptr<grpc::protobuf::Message> response_prototype_;
std::vector<const google::protobuf::ServiceDescriptor*> service_desc_list_;
}; };
} // namespace testing } // namespace testing

Loading…
Cancel
Save