Introduce GRPC_ARG_TSI_MAX_FRAME_SIZE channel arg.

Introduce GRPC_ARG_TSI_MAX_FRAME_SIZE so that users can use larger than
14KiB frame size if they need to.
pull/20268/head
Soheil Hassas Yeganeh 5 years ago
parent 11729aeaa7
commit dd6e6e3ef7
  1. 8
      include/grpc/impl/codegen/grpc_types.h
  2. 16
      src/core/lib/http/httpcli_security_connector.cc
  3. 8
      src/core/lib/security/security_connector/alts/alts_security_connector.cc
  4. 10
      src/core/lib/security/security_connector/fake/fake_security_connector.cc
  5. 8
      src/core/lib/security/security_connector/local/local_security_connector.cc
  6. 1
      src/core/lib/security/security_connector/security_connector.cc
  7. 22
      src/core/lib/security/security_connector/security_connector.h
  8. 10
      src/core/lib/security/security_connector/ssl/ssl_security_connector.cc
  9. 8
      src/core/lib/security/security_connector/tls/spiffe_security_connector.cc
  10. 6
      src/core/lib/security/security_connector/tls/spiffe_security_connector.h
  11. 38
      src/core/lib/security/transport/security_handshaker.cc
  12. 6
      src/core/lib/security/transport/security_handshaker.h
  13. 2
      src/core/tsi/alts/frame_protector/alts_frame_protector.cc

@ -267,6 +267,14 @@ typedef struct {
grpc_ssl_session_cache*). (use grpc_ssl_session_cache_arg_vtable() to fetch grpc_ssl_session_cache*). (use grpc_ssl_session_cache_arg_vtable() to fetch
an appropriate pointer arg vtable) */ an appropriate pointer arg vtable) */
#define GRPC_SSL_SESSION_CACHE_ARG "grpc.ssl_session_cache" #define GRPC_SSL_SESSION_CACHE_ARG "grpc.ssl_session_cache"
/** If non-zero, it will determine the maximum frame size used by TSI's frame
* protector.
*
* NOTE: Be aware that using a large "max_frame_size" is memory inefficient
* for non-zerocopy protectors. Also, increasing this value above 1MiB
* can break old binaries that don't support larger than 1MiB frame
* size. */
#define GRPC_ARG_TSI_MAX_FRAME_SIZE "grpc.tsi.max_frame_size"
/** Maximum metadata size, in bytes. Note this limit applies to the max sum of /** Maximum metadata size, in bytes. Note this limit applies to the max sum of
all metadata key-value entries in a batch of headers. */ all metadata key-value entries in a batch of headers. */
#define GRPC_ARG_MAX_METADATA_SIZE "grpc.max_metadata_size" #define GRPC_ARG_MAX_METADATA_SIZE "grpc.max_metadata_size"

@ -41,7 +41,7 @@
class grpc_httpcli_ssl_channel_security_connector final class grpc_httpcli_ssl_channel_security_connector final
: public grpc_channel_security_connector { : public grpc_channel_security_connector {
public: public:
explicit grpc_httpcli_ssl_channel_security_connector(char* secure_peer_name) grpc_httpcli_ssl_channel_security_connector(char* secure_peer_name)
: grpc_channel_security_connector( : grpc_channel_security_connector(
/*url_scheme=*/nullptr, /*url_scheme=*/nullptr,
/*channel_creds=*/nullptr, /*channel_creds=*/nullptr,
@ -66,7 +66,8 @@ class grpc_httpcli_ssl_channel_security_connector final
&options, &handshaker_factory_); &options, &handshaker_factory_);
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
if (handshaker_factory_ != nullptr) { if (handshaker_factory_ != nullptr) {
@ -77,7 +78,8 @@ class grpc_httpcli_ssl_channel_security_connector final
tsi_result_to_string(result)); tsi_result_to_string(result));
} }
} }
handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(handshaker, this)); handshake_mgr->Add(
grpc_core::SecurityHandshakerCreate(handshaker, this, args));
} }
tsi_ssl_client_handshaker_factory* handshaker_factory() const { tsi_ssl_client_handshaker_factory* handshaker_factory() const {
@ -132,7 +134,7 @@ class grpc_httpcli_ssl_channel_security_connector final
static grpc_core::RefCountedPtr<grpc_channel_security_connector> static grpc_core::RefCountedPtr<grpc_channel_security_connector>
httpcli_ssl_channel_security_connector_create( httpcli_ssl_channel_security_connector_create(
const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store, const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store,
const char* secure_peer_name) { const char* secure_peer_name, grpc_channel_args* channel_args) {
if (secure_peer_name != nullptr && pem_root_certs == nullptr) { if (secure_peer_name != nullptr && pem_root_certs == nullptr) {
gpr_log(GPR_ERROR, gpr_log(GPR_ERROR,
"Cannot assert a secure peer name without a trust root."); "Cannot assert a secure peer name without a trust root.");
@ -192,8 +194,10 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host,
c->func = on_done; c->func = on_done;
c->arg = arg; c->arg = arg;
grpc_core::RefCountedPtr<grpc_channel_security_connector> sc = grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
httpcli_ssl_channel_security_connector_create(pem_root_certs, root_store, httpcli_ssl_channel_security_connector_create(
host); pem_root_certs, root_store, host,
static_cast<grpc_core::HandshakerArgs*>(arg)->args);
GPR_ASSERT(sc != nullptr); GPR_ASSERT(sc != nullptr);
grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get()); grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get());
grpc_channel_args args = {1, &channel_arg}; grpc_channel_args args = {1, &channel_arg};

@ -81,7 +81,7 @@ class grpc_alts_channel_security_connector final
~grpc_alts_channel_security_connector() override { gpr_free(target_name_); } ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); }
void add_handshakers( void add_handshakers(
grpc_pollset_set* interested_parties, const grpc_channel_args* args, grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_manager) override { grpc_core::HandshakeManager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
const grpc_alts_credentials* creds = const grpc_alts_credentials* creds =
@ -91,7 +91,7 @@ class grpc_alts_channel_security_connector final
interested_parties, interested_parties,
&handshaker) == TSI_OK); &handshaker) == TSI_OK);
handshake_manager->Add( handshake_manager->Add(
grpc_core::SecurityHandshakerCreate(handshaker, this)); grpc_core::SecurityHandshakerCreate(handshaker, this, args));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,
@ -142,7 +142,7 @@ class grpc_alts_server_security_connector final
~grpc_alts_server_security_connector() override = default; ~grpc_alts_server_security_connector() override = default;
void add_handshakers( void add_handshakers(
grpc_pollset_set* interested_parties, const grpc_channel_args* args, grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_manager) override { grpc_core::HandshakeManager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
const grpc_alts_server_credentials* creds = const grpc_alts_server_credentials* creds =
@ -151,7 +151,7 @@ class grpc_alts_server_security_connector final
creds->options(), nullptr, creds->handshaker_service_url(), creds->options(), nullptr, creds->handshaker_service_url(),
false, interested_parties, &handshaker) == TSI_OK); false, interested_parties, &handshaker) == TSI_OK);
handshake_manager->Add( handshake_manager->Add(
grpc_core::SecurityHandshakerCreate(handshaker, this)); grpc_core::SecurityHandshakerCreate(handshaker, this, args));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,

@ -96,10 +96,11 @@ class grpc_fake_channel_security_connector final
return GPR_ICMP(is_lb_channel_, other->is_lb_channel_); return GPR_ICMP(is_lb_channel_, other->is_lb_channel_);
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
handshake_mgr->Add(grpc_core::SecurityHandshakerCreate( handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(
tsi_create_fake_handshaker(/*is_client=*/true), this)); tsi_create_fake_handshaker(/*is_client=*/true), this, args));
} }
bool check_call_host(grpc_core::StringView host, bool check_call_host(grpc_core::StringView host,
@ -271,10 +272,11 @@ class grpc_fake_server_security_connector
fake_check_peer(this, peer, auth_context, on_peer_checked); fake_check_peer(this, peer, auth_context, on_peer_checked);
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
handshake_mgr->Add(grpc_core::SecurityHandshakerCreate( handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(
tsi_create_fake_handshaker(/*=is_client*/ false), this)); tsi_create_fake_handshaker(/*=is_client*/ false), this, args));
} }
int cmp(const grpc_security_connector* other) const override { int cmp(const grpc_security_connector* other) const override {

@ -129,13 +129,13 @@ class grpc_local_channel_security_connector final
~grpc_local_channel_security_connector() override { gpr_free(target_name_); } ~grpc_local_channel_security_connector() override { gpr_free(target_name_); }
void add_handshakers( void add_handshakers(
grpc_pollset_set* interested_parties, const grpc_channel_args* args, grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_manager) override { grpc_core::HandshakeManager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) == GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) ==
TSI_OK); TSI_OK);
handshake_manager->Add( handshake_manager->Add(
grpc_core::SecurityHandshakerCreate(handshaker, this)); grpc_core::SecurityHandshakerCreate(handshaker, this, args));
} }
int cmp(const grpc_security_connector* other_sc) const override { int cmp(const grpc_security_connector* other_sc) const override {
@ -187,13 +187,13 @@ class grpc_local_server_security_connector final
~grpc_local_server_security_connector() override = default; ~grpc_local_server_security_connector() override = default;
void add_handshakers( void add_handshakers(
grpc_pollset_set* interested_parties, const grpc_channel_args* args, grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_manager) override { grpc_core::HandshakeManager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr; tsi_handshaker* handshaker = nullptr;
GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */,
&handshaker) == TSI_OK); &handshaker) == TSI_OK);
handshake_manager->Add( handshake_manager->Add(
grpc_core::SecurityHandshakerCreate(handshaker, this)); grpc_core::SecurityHandshakerCreate(handshaker, this, args));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,

@ -53,6 +53,7 @@ grpc_channel_security_connector::grpc_channel_security_connector(
: grpc_security_connector(url_scheme), : grpc_security_connector(url_scheme),
channel_creds_(std::move(channel_creds)), channel_creds_(std::move(channel_creds)),
request_metadata_creds_(std::move(request_metadata_creds)) {} request_metadata_creds_(std::move(request_metadata_creds)) {}
grpc_channel_security_connector::~grpc_channel_security_connector() {} grpc_channel_security_connector::~grpc_channel_security_connector() {}
int grpc_security_connector_cmp(const grpc_security_connector* sc, int grpc_security_connector_cmp(const grpc_security_connector* sc,

@ -91,7 +91,9 @@ class grpc_channel_security_connector : public grpc_security_connector {
grpc_channel_security_connector( grpc_channel_security_connector(
const char* url_scheme, const char* url_scheme,
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds); grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds
/*,
grpc_channel_args* channel_args = nullptr*/);
~grpc_channel_security_connector() override; ~grpc_channel_security_connector() override;
/// Checks that the host that will be set for a call is acceptable. /// Checks that the host that will be set for a call is acceptable.
@ -108,9 +110,9 @@ class grpc_channel_security_connector : public grpc_security_connector {
virtual void cancel_check_call_host(grpc_closure* on_call_host_checked, virtual void cancel_check_call_host(grpc_closure* on_call_host_checked,
grpc_error* error) GRPC_ABSTRACT; grpc_error* error) GRPC_ABSTRACT;
/// Registers handshakers with \a handshake_mgr. /// Registers handshakers with \a handshake_mgr.
virtual void add_handshakers(grpc_pollset_set* interested_parties, virtual void add_handshakers(
grpc_core::HandshakeManager* handshake_mgr) const grpc_channel_args* args, grpc_pollset_set* interested_parties,
GRPC_ABSTRACT; grpc_core::HandshakeManager* handshake_mgr) GRPC_ABSTRACT;
const grpc_channel_credentials* channel_creds() const { const grpc_channel_credentials* channel_creds() const {
return channel_creds_.get(); return channel_creds_.get();
@ -132,9 +134,15 @@ class grpc_channel_security_connector : public grpc_security_connector {
int channel_security_connector_cmp( int channel_security_connector_cmp(
const grpc_channel_security_connector* other) const; const grpc_channel_security_connector* other) const;
// grpc_channel_args* channel_args() const { return channel_args_.get(); }
//// Should be called as soon as the channel args are not needed to reduce
//// memory usage.
// void clear_channel_arg() { channel_args_.reset(); }
private: private:
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds_; grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds_;
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds_; grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds_;
grpc_core::UniquePtr<grpc_channel_args> channel_args_;
}; };
/* --- server_security_connector object. --- /* --- server_security_connector object. ---
@ -149,9 +157,9 @@ class grpc_server_security_connector : public grpc_security_connector {
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
~grpc_server_security_connector() override = default; ~grpc_server_security_connector() override = default;
virtual void add_handshakers(grpc_pollset_set* interested_parties, virtual void add_handshakers(
grpc_core::HandshakeManager* handshake_mgr) const grpc_channel_args* args, grpc_pollset_set* interested_parties,
GRPC_ABSTRACT; grpc_core::HandshakeManager* handshake_mgr) GRPC_ABSTRACT;
const grpc_server_credentials* server_creds() const { const grpc_server_credentials* server_creds() const {
return server_creds_.get(); return server_creds_.get();

@ -116,7 +116,8 @@ class grpc_ssl_channel_security_connector final
return GRPC_SECURITY_OK; return GRPC_SECURITY_OK;
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
// Instantiate TSI handshaker. // Instantiate TSI handshaker.
tsi_handshaker* tsi_hs = nullptr; tsi_handshaker* tsi_hs = nullptr;
@ -131,7 +132,7 @@ class grpc_ssl_channel_security_connector final
return; return;
} }
// Create handshakers. // Create handshakers.
handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this)); handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this, args));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,
@ -278,7 +279,8 @@ class grpc_ssl_server_security_connector
return GRPC_SECURITY_OK; return GRPC_SECURITY_OK;
} }
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_mgr) override { grpc_core::HandshakeManager* handshake_mgr) override {
// Instantiate TSI handshaker. // Instantiate TSI handshaker.
try_fetch_ssl_server_credentials(); try_fetch_ssl_server_credentials();
@ -291,7 +293,7 @@ class grpc_ssl_server_security_connector
return; return;
} }
// Create handshakers. // Create handshakers.
handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this)); handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this, args));
} }
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,

@ -138,7 +138,7 @@ SpiffeChannelSecurityConnector::~SpiffeChannelSecurityConnector() {
} }
void SpiffeChannelSecurityConnector::add_handshakers( void SpiffeChannelSecurityConnector::add_handshakers(
grpc_pollset_set* interested_parties, const grpc_channel_args* args, grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_mgr) { grpc_core::HandshakeManager* handshake_mgr) {
if (RefreshHandshakerFactory() != GRPC_SECURITY_OK) { if (RefreshHandshakerFactory() != GRPC_SECURITY_OK) {
gpr_log(GPR_ERROR, "Handshaker factory refresh failed."); gpr_log(GPR_ERROR, "Handshaker factory refresh failed.");
@ -157,7 +157,7 @@ void SpiffeChannelSecurityConnector::add_handshakers(
return; return;
} }
// Create handshakers. // Create handshakers.
handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this)); handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this, args));
} }
void SpiffeChannelSecurityConnector::check_peer( void SpiffeChannelSecurityConnector::check_peer(
@ -412,7 +412,7 @@ SpiffeServerSecurityConnector::~SpiffeServerSecurityConnector() {
} }
void SpiffeServerSecurityConnector::add_handshakers( void SpiffeServerSecurityConnector::add_handshakers(
grpc_pollset_set* interested_parties, const grpc_channel_args* args, grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_mgr) { grpc_core::HandshakeManager* handshake_mgr) {
/* Refresh handshaker factory if needed. */ /* Refresh handshaker factory if needed. */
if (RefreshHandshakerFactory() != GRPC_SECURITY_OK) { if (RefreshHandshakerFactory() != GRPC_SECURITY_OK) {
@ -428,7 +428,7 @@ void SpiffeServerSecurityConnector::add_handshakers(
tsi_result_to_string(result)); tsi_result_to_string(result));
return; return;
} }
handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this)); handshake_mgr->Add(grpc_core::SecurityHandshakerCreate(tsi_hs, this, args));
} }
void SpiffeServerSecurityConnector::check_peer( void SpiffeServerSecurityConnector::check_peer(

@ -47,7 +47,8 @@ class SpiffeChannelSecurityConnector final
const char* target_name, const char* overridden_target_name); const char* target_name, const char* overridden_target_name);
~SpiffeChannelSecurityConnector() override; ~SpiffeChannelSecurityConnector() override;
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_mgr) override; grpc_core::HandshakeManager* handshake_mgr) override;
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,
@ -117,7 +118,8 @@ class SpiffeServerSecurityConnector final
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
~SpiffeServerSecurityConnector() override; ~SpiffeServerSecurityConnector() override;
void add_handshakers(grpc_pollset_set* interested_parties, void add_handshakers(const grpc_channel_args* args,
grpc_pollset_set* interested_parties,
grpc_core::HandshakeManager* handshake_mgr) override; grpc_core::HandshakeManager* handshake_mgr) override;
void check_peer(tsi_peer peer, grpc_endpoint* ep, void check_peer(tsi_peer peer, grpc_endpoint* ep,

@ -22,6 +22,7 @@
#include <stdbool.h> #include <stdbool.h>
#include <string.h> #include <string.h>
#include <limits>
#include <grpc/slice_buffer.h> #include <grpc/slice_buffer.h>
#include <grpc/support/alloc.h> #include <grpc/support/alloc.h>
@ -46,7 +47,8 @@ namespace {
class SecurityHandshaker : public Handshaker { class SecurityHandshaker : public Handshaker {
public: public:
SecurityHandshaker(tsi_handshaker* handshaker, SecurityHandshaker(tsi_handshaker* handshaker,
grpc_security_connector* connector); grpc_security_connector* connector,
const grpc_channel_args* args);
~SecurityHandshaker() override; ~SecurityHandshaker() override;
void Shutdown(grpc_error* why) override; void Shutdown(grpc_error* why) override;
void DoHandshake(grpc_tcp_server_acceptor* acceptor, void DoHandshake(grpc_tcp_server_acceptor* acceptor,
@ -97,15 +99,23 @@ class SecurityHandshaker : public Handshaker {
grpc_closure on_peer_checked_; grpc_closure on_peer_checked_;
RefCountedPtr<grpc_auth_context> auth_context_; RefCountedPtr<grpc_auth_context> auth_context_;
tsi_handshaker_result* handshaker_result_ = nullptr; tsi_handshaker_result* handshaker_result_ = nullptr;
size_t max_frame_size_ = 0;
}; };
SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker, SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
grpc_security_connector* connector) grpc_security_connector* connector,
const grpc_channel_args* args)
: handshaker_(handshaker), : handshaker_(handshaker),
connector_(connector->Ref(DEBUG_LOCATION, "handshake")), connector_(connector->Ref(DEBUG_LOCATION, "handshake")),
handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE), handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
handshake_buffer_( handshake_buffer_(
static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))) { static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))) {
const grpc_arg* arg =
grpc_channel_args_find(args, GRPC_ARG_TSI_MAX_FRAME_SIZE);
if (arg != nullptr && arg->type == GRPC_ARG_INTEGER) {
max_frame_size_ = grpc_channel_arg_get_integer(
arg, {0, 0, std::numeric_limits<int>::max()});
}
gpr_mu_init(&mu_); gpr_mu_init(&mu_);
grpc_slice_buffer_init(&outgoing_); grpc_slice_buffer_init(&outgoing_);
GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer_, GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer_,
@ -201,7 +211,8 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error* error) {
// Create zero-copy frame protector, if implemented. // Create zero-copy frame protector, if implemented.
tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr; tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr;
tsi_result result = tsi_handshaker_result_create_zero_copy_grpc_protector( tsi_result result = tsi_handshaker_result_create_zero_copy_grpc_protector(
handshaker_result_, nullptr, &zero_copy_protector); handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
&zero_copy_protector);
if (result != TSI_OK && result != TSI_UNIMPLEMENTED) { if (result != TSI_OK && result != TSI_UNIMPLEMENTED) {
error = grpc_set_tsi_error_result( error = grpc_set_tsi_error_result(
GRPC_ERROR_CREATE_FROM_STATIC_STRING( GRPC_ERROR_CREATE_FROM_STATIC_STRING(
@ -213,8 +224,9 @@ void SecurityHandshaker::OnPeerCheckedInner(grpc_error* error) {
// Create frame protector if zero-copy frame protector is NULL. // Create frame protector if zero-copy frame protector is NULL.
tsi_frame_protector* protector = nullptr; tsi_frame_protector* protector = nullptr;
if (zero_copy_protector == nullptr) { if (zero_copy_protector == nullptr) {
result = tsi_handshaker_result_create_frame_protector(handshaker_result_, result = tsi_handshaker_result_create_frame_protector(
nullptr, &protector); handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
&protector);
if (result != TSI_OK) { if (result != TSI_OK) {
error = grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING( error = grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Frame protector creation failed"), "Frame protector creation failed"),
@ -459,7 +471,8 @@ class ClientSecurityHandshakerFactory : public HandshakerFactory {
reinterpret_cast<grpc_channel_security_connector*>( reinterpret_cast<grpc_channel_security_connector*>(
grpc_security_connector_find_in_args(args)); grpc_security_connector_find_in_args(args));
if (security_connector) { if (security_connector) {
security_connector->add_handshakers(interested_parties, handshake_mgr); security_connector->add_handshakers(args, interested_parties,
handshake_mgr);
} }
} }
~ClientSecurityHandshakerFactory() override = default; ~ClientSecurityHandshakerFactory() override = default;
@ -474,7 +487,8 @@ class ServerSecurityHandshakerFactory : public HandshakerFactory {
reinterpret_cast<grpc_server_security_connector*>( reinterpret_cast<grpc_server_security_connector*>(
grpc_security_connector_find_in_args(args)); grpc_security_connector_find_in_args(args));
if (security_connector) { if (security_connector) {
security_connector->add_handshakers(interested_parties, handshake_mgr); security_connector->add_handshakers(args, interested_parties,
handshake_mgr);
} }
} }
~ServerSecurityHandshakerFactory() override = default; ~ServerSecurityHandshakerFactory() override = default;
@ -487,13 +501,14 @@ class ServerSecurityHandshakerFactory : public HandshakerFactory {
// //
RefCountedPtr<Handshaker> SecurityHandshakerCreate( RefCountedPtr<Handshaker> SecurityHandshakerCreate(
tsi_handshaker* handshaker, grpc_security_connector* connector) { tsi_handshaker* handshaker, grpc_security_connector* connector,
const grpc_channel_args* args) {
// If no TSI handshaker was created, return a handshaker that always fails. // If no TSI handshaker was created, return a handshaker that always fails.
// Otherwise, return a real security handshaker. // Otherwise, return a real security handshaker.
if (handshaker == nullptr) { if (handshaker == nullptr) {
return MakeRefCounted<FailHandshaker>(); return MakeRefCounted<FailHandshaker>();
} else { } else {
return MakeRefCounted<SecurityHandshaker>(handshaker, connector); return MakeRefCounted<SecurityHandshaker>(handshaker, connector, args);
} }
} }
@ -509,6 +524,7 @@ void SecurityRegisterHandshakerFactories() {
} // namespace grpc_core } // namespace grpc_core
grpc_handshaker* grpc_security_handshaker_create( grpc_handshaker* grpc_security_handshaker_create(
tsi_handshaker* handshaker, grpc_security_connector* connector) { tsi_handshaker* handshaker, grpc_security_connector* connector,
return SecurityHandshakerCreate(handshaker, connector).release(); const grpc_channel_args* args) {
return SecurityHandshakerCreate(handshaker, connector, args).release();
} }

@ -28,7 +28,8 @@ namespace grpc_core {
/// Creates a security handshaker using \a handshaker. /// Creates a security handshaker using \a handshaker.
RefCountedPtr<Handshaker> SecurityHandshakerCreate( RefCountedPtr<Handshaker> SecurityHandshakerCreate(
tsi_handshaker* handshaker, grpc_security_connector* connector); tsi_handshaker* handshaker, grpc_security_connector* connector,
const grpc_channel_args* args);
/// Registers security handshaker factories. /// Registers security handshaker factories.
void SecurityRegisterHandshakerFactories(); void SecurityRegisterHandshakerFactories();
@ -38,6 +39,7 @@ void SecurityRegisterHandshakerFactories();
// TODO(arjunroy): This is transitional to account for the new handshaker API // TODO(arjunroy): This is transitional to account for the new handshaker API
// and will eventually be removed entirely. // and will eventually be removed entirely.
grpc_handshaker* grpc_security_handshaker_create( grpc_handshaker* grpc_security_handshaker_create(
tsi_handshaker* handshaker, grpc_security_connector* connector); tsi_handshaker* handshaker, grpc_security_connector* connector,
const grpc_channel_args* args);
#endif /* GRPC_CORE_LIB_SECURITY_TRANSPORT_SECURITY_HANDSHAKER_H */ #endif /* GRPC_CORE_LIB_SECURITY_TRANSPORT_SECURITY_HANDSHAKER_H */

@ -34,7 +34,7 @@
constexpr size_t kMinFrameLength = 1024; constexpr size_t kMinFrameLength = 1024;
constexpr size_t kDefaultFrameLength = 16 * 1024; constexpr size_t kDefaultFrameLength = 16 * 1024;
constexpr size_t kMaxFrameLength = 1024 * 1024; constexpr size_t kMaxFrameLength = 16 * 1024 * 1024;
// Limit k on number of frames such that at most 2^(8 * k) frames can be sent. // Limit k on number of frames such that at most 2^(8 * k) frames can be sent.
constexpr size_t kAltsRecordProtocolRekeyFrameLimit = 8; constexpr size_t kAltsRecordProtocolRekeyFrameLimit = 8;

Loading…
Cancel
Save