diff --git a/test/cpp/util/grpc_tool.cc b/test/cpp/util/grpc_tool.cc index e227e6027dc..a181ee7597b 100644 --- a/test/cpp/util/grpc_tool.cc +++ b/test/cpp/util/grpc_tool.cc @@ -31,7 +31,7 @@ * */ -#include "test/cpp/util/grpc_tool.h" +#include "grpc_tool.h" #include #include @@ -55,15 +55,14 @@ DEFINE_bool(enable_ssl, false, "Whether to use ssl/tls."); DEFINE_bool(use_auth, false, "Whether to create default google credentials."); -DEFINE_string(input_binary_file, "", - "Path to input file containing serialized request."); -DEFINE_string(output_binary_file, "", - "Path to output file to write serialized response."); +DEFINE_bool(remotedb, true, "Use server types to parse and format messages"); DEFINE_string(metadata, "", "Metadata to send to server, in the form of key1:val1:key2:val2"); DEFINE_string(proto_path, ".", "Path to look for the proto file."); -// TODO(zyc): support a list of input proto files -DEFINE_string(protofiles, "", "Name of the proto file."); +DEFINE_string(proto_file, "", "Name of the proto file."); +DEFINE_bool(binary_input, false, "Input in binary format"); +DEFINE_bool(binary_output, false, "Output in binary format"); +DEFINE_string(infile, "", "Input file (default is stdin)"); namespace grpc { namespace testing { @@ -73,8 +72,22 @@ class GrpcTool { public: explicit GrpcTool(); virtual ~GrpcTool() {} + bool Help(int argc, const char** argv, GrpcToolOutputCallback callback); bool CallMethod(int argc, const char** argv, GrpcToolOutputCallback callback); + // TODO(zyc): implement the following methods + // bool ListServices(int argc, const char** argv, GrpcToolOutputCallback + // callback); + // bool PrintType(int argc, const char** argv, GrpcToolOutputCallback + // callback); + // bool PrintTypeId(int argc, const char** argv, GrpcToolOutputCallback + // callback); + // bool ParseMessage(int argc, const char** argv, GrpcToolOutputCallback + // callback); + // bool ToText(int argc, const char** argv, GrpcToolOutputCallback callback); + // bool ToBinary(int argc, const char** argv, GrpcToolOutputCallback + // callback); + void SetPrintCommandMode(int exit_status) { print_command_usage_ = true; usage_exit_status_ = exit_status; @@ -82,6 +95,7 @@ class GrpcTool { private: void CommandUsage(const grpc::string& usage) const; + std::shared_ptr NewChannel(const grpc::string& server_address); bool print_command_usage_; int usage_exit_status_; }; @@ -222,6 +236,21 @@ void GrpcTool::CommandUsage(const grpc::string& usage) const { } } +std::shared_ptr GrpcTool::NewChannel( + const grpc::string& server_address) { + std::shared_ptr creds; + if (!FLAGS_enable_ssl) { + creds = grpc::InsecureChannelCredentials(); + } else { + if (FLAGS_use_auth) { + creds = grpc::GoogleDefaultCredentials(); + } else { + creds = grpc::SslCredentials(grpc::SslCredentialsOptions()); + } + } + return grpc::CreateChannel(server_address, creds); +} + bool GrpcTool::Help(int argc, const char** argv, GrpcToolOutputCallback callback) { CommandUsage( @@ -250,17 +279,18 @@ bool GrpcTool::CallMethod(int argc, const char** argv, " ; Exported service name\n" " ; Method name\n" " ; Text protobuffer (overrides infile)\n" - " --protofiles ; Comma separated proto files used as a" + " --proto_file ; Comma separated proto files used as a" " fallback when parsing request/response\n" " --proto_path ; The search path of proto files, valid" - " only when --protofiles is given\n" + " only when --proto_file is given\n" " --metadata ; The metadata to be sent to the server\n" " --enable_ssl ; Set whether to use tls\n" " --use_auth ; Set whether to create default google" " credentials\n" + " --infile ; Input filename (defaults to stdin)\n" " --outfile ; Output filename (defaults to stdout)\n" - " --input_binary_file ; Path to input file in binary format\n" - " --binary_output ; Path to output file in binary format\n"); + " --binary_input ; Input in binary format\n" + " --binary_output ; Output in binary format\n"); std::stringstream output_ss; grpc::string request_text; @@ -271,63 +301,44 @@ bool GrpcTool::CallMethod(int argc, const char** argv, if (argc == 3) { request_text = argv[2]; - } - - std::shared_ptr creds; - if (!FLAGS_enable_ssl) { - creds = grpc::InsecureChannelCredentials(); + if (!FLAGS_infile.empty()) { + fprintf(stderr, "warning: request given in argv, ignoring --infile\n"); + } } else { - if (FLAGS_use_auth) { - creds = grpc::GoogleDefaultCredentials(); + std::stringstream input_stream; + if (FLAGS_infile.empty()) { + if (isatty(STDIN_FILENO)) { + fprintf(stderr, "reading request message from stdin...\n"); + } + input_stream << std::cin.rdbuf(); } else { - creds = grpc::SslCredentials(grpc::SslCredentialsOptions()); - } - } - std::shared_ptr channel = - grpc::CreateChannel(server_address, creds); - - if (request_text.empty() && FLAGS_input_binary_file.empty()) { - if (isatty(STDIN_FILENO)) { - std::cout << "reading request message from stdin..." << std::endl; + std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary); + input_stream << input_file.rdbuf(); + input_file.close(); } - std::stringstream input_stream; - input_stream << std::cin.rdbuf(); request_text = input_stream.str(); } - if (!request_text.empty()) { - if (!FLAGS_protofiles.empty()) { - parser.reset(new grpc::testing::ProtoFileParser( - FLAGS_proto_path, FLAGS_protofiles, method_name)); - } else { - parser.reset(new grpc::testing::ProtoFileParser(channel, method_name)); - } - method_name = parser->GetFullMethodName(); + std::shared_ptr channel = NewChannel(server_address); + if (!FLAGS_binary_input || !FLAGS_binary_output) { + parser.reset( + new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr, + FLAGS_proto_path, FLAGS_proto_file)); if (parser->HasError()) { - return 1; - } - - if (!FLAGS_input_binary_file.empty()) { - std::cout - << "warning: request given in argv, ignoring --input_binary_file" - << std::endl; + return false; } } - if (parser) { - serialized_request_proto = - parser->GetSerializedProto(request_text, true /* is_request */); + if (FLAGS_binary_input) { + serialized_request_proto = request_text; + } else { + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */); if (parser->HasError()) { - return 1; + return false; } - } else if (!FLAGS_input_binary_file.empty()) { - std::ifstream input_file(FLAGS_input_binary_file, - std::ios::in | std::ios::binary); - std::stringstream input_stream; - input_stream << input_file.rdbuf(); - serialized_request_proto = input_stream.str(); } - std::cout << "connecting to " << server_address << std::endl; + std::cerr << "connecting to " << server_address << std::endl; grpc::string serialized_response_proto; std::multimap client_metadata; @@ -336,30 +347,27 @@ bool GrpcTool::CallMethod(int argc, const char** argv, ParseMetadataFlag(&client_metadata); PrintMetadata(client_metadata, "Sending client initial metadata:"); grpc::Status s = grpc::testing::CliCall::Call( - channel, method_name, serialized_request_proto, - &serialized_response_proto, client_metadata, &server_initial_metadata, - &server_trailing_metadata); + channel, parser->GetFormatedMethodName(method_name), + serialized_request_proto, &serialized_response_proto, client_metadata, + &server_initial_metadata, &server_trailing_metadata); PrintMetadata(server_initial_metadata, "Received initial metadata from server:"); PrintMetadata(server_trailing_metadata, "Received trailing metadata from server:"); if (s.ok()) { - std::cout << "Rpc succeeded with OK status" << std::endl; - if (parser) { - grpc::string response_text = parser->GetTextFormat( - serialized_response_proto, false /* is_request */); + std::cerr << "Rpc succeeded with OK status" << std::endl; + if (FLAGS_binary_output) { + output_ss << serialized_response_proto; + } else { + grpc::string response_text = parser->GetTextFormatFromMethod( + method_name, serialized_response_proto, false /* is_request */); if (parser->HasError()) { return false; } output_ss << "Response: \n " << response_text << std::endl; } - if (!FLAGS_output_binary_file.empty()) { - std::ofstream output_file(FLAGS_output_binary_file, - std::ios::trunc | std::ios::binary); - output_file << serialized_response_proto; - } } else { - std::cout << "Rpc failed with status code " << s.error_code() + std::cerr << "Rpc failed with status code " << s.error_code() << ", error message: " << s.error_message() << std::endl; } diff --git a/test/cpp/util/proto_file_parser.cc b/test/cpp/util/proto_file_parser.cc index 5b0d925e1c3..f7ad17e809c 100644 --- a/test/cpp/util/proto_file_parser.cc +++ b/test/cpp/util/proto_file_parser.cc @@ -71,7 +71,7 @@ class ErrorPrinter void AddWarning(const grpc::string& filename, int line, int column, const grpc::string& message) GRPC_OVERRIDE { - std::cout << "warning " << filename << " " << line << " " << column << " " + std::cerr << "warning " << filename << " " << line << " " << column << " " << message << std::endl; } @@ -79,62 +79,72 @@ class ErrorPrinter ProtoFileParser* parser_; // not owned }; -ProtoFileParser::ProtoFileParser(const grpc::string& proto_path, - const grpc::string& file_name, - const grpc::string& method) +ProtoFileParser::ProtoFileParser(std::shared_ptr channel, + const grpc::string& proto_path, + const grpc::string& protofiles) : has_error_(false) { - source_tree_.MapPath("", proto_path); - error_printer_.reset(new ErrorPrinter(this)); - importer_.reset(new google::protobuf::compiler::Importer( - &source_tree_, error_printer_.get())); - const auto* file_desc = importer_->Import(file_name); - if (!file_desc) { - LogError(""); - return; + std::vector service_list; + if (channel) { + reflection_db_.reset(new grpc::ProtoReflectionDescriptorDatabase(channel)); + reflection_db_->GetServices(&service_list); } - dynamic_factory_.reset( - new google::protobuf::DynamicMessageFactory(importer_->pool())); - std::vector service_desc_list; - for (int i = 0; i < file_desc->service_count(); i++) { - service_desc_list.push_back(file_desc->service(i)); - } - InitProtoFileParser(method, service_desc_list); -} + if (!protofiles.empty()) { + source_tree_.MapPath("", proto_path); + error_printer_.reset(new ErrorPrinter(this)); + importer_.reset(new google::protobuf::compiler::Importer( + &source_tree_, error_printer_.get())); -ProtoFileParser::ProtoFileParser(std::shared_ptr channel, - const grpc::string& method) - : has_error_(false), - desc_db_(new grpc::ProtoReflectionDescriptorDatabase(channel)), - desc_pool_(new google::protobuf::DescriptorPool(desc_db_.get())) { - std::vector service_list; - if (!desc_db_->GetServices(&service_list)) { - LogError( - "Failed to get services from the server, " - "it may not have the reflection service.\n" - "Please try to use the --protofiles option to provide a proto file."); + grpc::string file_name; + std::stringstream ss(protofiles); + while (std::getline(ss, file_name, ',')) { + std::cerr << file_name << std::endl; + const auto* file_desc = importer_->Import(file_name); + if (file_desc) { + for (int i = 0; i < file_desc->service_count(); i++) { + service_desc_list_.push_back(file_desc->service(i)); + } + } else { + std::cerr << file_name << " not found" << std::endl; + } + } + + file_db_.reset( + new google::protobuf::DescriptorPoolDatabase(*importer_->pool())); } - if (has_error_) { + + if (!reflection_db_ && !file_db_) { + LogError("No available proto database"); return; } + + if (!reflection_db_) { + desc_db_ = std::move(file_db_); + } else if (!file_db_) { + desc_db_ = std::move(reflection_db_); + } else { + desc_db_.reset(new google::protobuf::MergedDescriptorDatabase( + reflection_db_.get(), file_db_.get())); + } + + desc_pool_.reset(new google::protobuf::DescriptorPool(desc_db_.get())); dynamic_factory_.reset( new google::protobuf::DynamicMessageFactory(desc_pool_.get())); - std::vector service_desc_list; for (auto it = service_list.begin(); it != service_list.end(); it++) { - service_desc_list.push_back(desc_pool_->FindServiceByName(*it)); + if (const google::protobuf::ServiceDescriptor* service_desc = + desc_pool_->FindServiceByName(*it)) { + service_desc_list_.push_back(service_desc); + } } - InitProtoFileParser(method, service_desc_list); } ProtoFileParser::~ProtoFileParser() {} -void ProtoFileParser::InitProtoFileParser( - const grpc::string& method, - const std::vector - service_desc_list) { +grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) { + has_error_ = false; const google::protobuf::MethodDescriptor* method_descriptor = nullptr; - for (auto it = service_desc_list.begin(); it != service_desc_list.end(); + for (auto it = service_desc_list_.begin(); it != service_desc_list_.end(); it++) { const auto* service_desc = *it; for (int j = 0; j < service_desc->method_count(); j++) { @@ -154,26 +164,80 @@ void ProtoFileParser::InitProtoFileParser( LogError("Method name not found"); } if (has_error_) { - return; + return ""; } - full_method_name_ = method_descriptor->full_name(); - size_t last_dot = full_method_name_.find_last_of('.'); + + return method_descriptor->full_name(); +} + +grpc::string ProtoFileParser::GetFormatedMethodName( + const grpc::string& method) { + has_error_ = false; + grpc::string formated_method_name = GetFullMethodName(method); + if (has_error_) { + return ""; + } + size_t last_dot = formated_method_name.find_last_of('.'); if (last_dot != grpc::string::npos) { - full_method_name_[last_dot] = '/'; + formated_method_name[last_dot] = '/'; + } + formated_method_name.insert(formated_method_name.begin(), '/'); + return formated_method_name; +} + +grpc::string ProtoFileParser::GetMessageTypeFromMethod( + const grpc::string& method, bool is_request) { + has_error_ = false; + grpc::string full_method_name = GetFullMethodName(method); + if (has_error_) { + return ""; + } + const google::protobuf::MethodDescriptor* method_desc = + desc_pool_->FindMethodByName(full_method_name); + if (!method_desc) { + LogError("Method not found"); + return ""; } - full_method_name_.insert(full_method_name_.begin(), '/'); - request_prototype_.reset( - dynamic_factory_->GetPrototype(method_descriptor->input_type())->New()); - response_prototype_.reset( - dynamic_factory_->GetPrototype(method_descriptor->output_type())->New()); + return is_request ? method_desc->input_type()->full_name() + : method_desc->output_type()->full_name(); } -grpc::string ProtoFileParser::GetSerializedProto( - const grpc::string& text_format_proto, bool is_request) { +grpc::string ProtoFileParser::GetSerializedProtoFromMethod( + const grpc::string& method, const grpc::string& text_format_proto, + bool is_request) { + has_error_ = false; + grpc::string message_type_name = GetMessageTypeFromMethod(method, is_request); + if (has_error_) { + return ""; + } + return GetSerializedProtoFromMessageType(message_type_name, + text_format_proto); +} + +grpc::string ProtoFileParser::GetTextFormatFromMethod( + const grpc::string& method, const grpc::string& serialized_proto, + bool is_request) { + has_error_ = false; + grpc::string message_type_name = GetMessageTypeFromMethod(method, is_request); + if (has_error_) { + return ""; + } + return GetTextFormatFromMessageType(message_type_name, serialized_proto); +} + +grpc::string ProtoFileParser::GetSerializedProtoFromMessageType( + const grpc::string& message_type_name, + const grpc::string& text_format_proto) { + has_error_ = false; grpc::string serialized; - grpc::protobuf::Message* msg = - is_request ? request_prototype_.get() : response_prototype_.get(); + const google::protobuf::Descriptor* desc = + desc_pool_->FindMessageTypeByName(message_type_name); + if (!desc) { + LogError("Message type not found"); + return ""; + } + grpc::protobuf::Message* msg = dynamic_factory_->GetPrototype(desc)->New(); bool ok = google::protobuf::TextFormat::ParseFromString(text_format_proto, msg); if (!ok) { @@ -188,10 +252,17 @@ grpc::string ProtoFileParser::GetSerializedProto( return serialized; } -grpc::string ProtoFileParser::GetTextFormat( - const grpc::string& serialized_proto, bool is_request) { - grpc::protobuf::Message* msg = - is_request ? request_prototype_.get() : response_prototype_.get(); +grpc::string ProtoFileParser::GetTextFormatFromMessageType( + const grpc::string& message_type_name, + const grpc::string& serialized_proto) { + has_error_ = false; + const google::protobuf::Descriptor* desc = + desc_pool_->FindMessageTypeByName(message_type_name); + if (!desc) { + LogError("Message type not found"); + return ""; + } + grpc::protobuf::Message* msg = dynamic_factory_->GetPrototype(desc)->New(); if (!msg->ParseFromString(serialized_proto)) { LogError("Failed to deserialize proto."); return ""; @@ -206,7 +277,7 @@ grpc::string ProtoFileParser::GetTextFormat( void ProtoFileParser::LogError(const grpc::string& error_msg) { if (!error_msg.empty()) { - std::cout << error_msg << std::endl; + std::cerr << error_msg << std::endl; } has_error_ = true; } diff --git a/test/cpp/util/proto_file_parser.h b/test/cpp/util/proto_file_parser.h index b442d77db98..c8afc4b4b71 100644 --- a/test/cpp/util/proto_file_parser.h +++ b/test/cpp/util/proto_file_parser.h @@ -53,41 +53,57 @@ class ProtoFileParser { // The given proto file_name will be searched in a source tree rooted from // proto_path. The method could be a partial string such as Service.Method or // even just Method. It will log an error if there is ambiguity. - ProtoFileParser(const grpc::string& proto_path, const grpc::string& file_name, - const grpc::string& method); - ProtoFileParser(std::shared_ptr channel, - const grpc::string& method); + const grpc::string& proto_path, + const grpc::string& protofiles); + ~ProtoFileParser(); - grpc::string GetFullMethodName() const { return full_method_name_; } + // Full method name is in the form of Service.Method, it's good to be used in + // descriptor database queries. + grpc::string GetFullMethodName(const grpc::string& method); + + // Formated method name is in the form of /Service/Method, it's good to be + // used as the argument of Stub::Call() + grpc::string GetFormatedMethodName(const grpc::string& method); + + grpc::string GetSerializedProtoFromMethod( + const grpc::string& method, const grpc::string& text_format_proto, + bool is_request); + + grpc::string GetTextFormatFromMethod(const grpc::string& method, + const grpc::string& serialized_proto, + bool is_request); - grpc::string GetSerializedProto(const grpc::string& text_format_proto, - bool is_request); + grpc::string GetSerializedProtoFromMessageType( + const grpc::string& message_type_name, + const grpc::string& text_format_proto); - grpc::string GetTextFormat(const grpc::string& serialized_proto, - bool is_request); + grpc::string GetTextFormatFromMessageType( + const grpc::string& message_type_name, + const grpc::string& serialized_proto); bool HasError() const { return has_error_; } void LogError(const grpc::string& error_msg); private: - void InitProtoFileParser( - const grpc::string& method, - const std::vector services); + grpc::string GetMessageTypeFromMethod(const grpc::string& method, + bool is_request); bool has_error_; grpc::string request_text_; - grpc::string full_method_name_; google::protobuf::compiler::DiskSourceTree source_tree_; std::unique_ptr error_printer_; std::unique_ptr importer_; - std::unique_ptr desc_db_; + std::unique_ptr reflection_db_; + std::unique_ptr file_db_; + std::unique_ptr desc_db_; std::unique_ptr desc_pool_; std::unique_ptr dynamic_factory_; std::unique_ptr request_prototype_; std::unique_ptr response_prototype_; + std::vector service_desc_list_; }; } // namespace testing