diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc index c309958fe5b..5c78b7635a3 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc @@ -210,6 +210,13 @@ grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) { gpr_log(GPR_ERROR, "Mismatch of local and peer rpc protocol versions."); return nullptr; } + /* Validate ALTS Context. */ + const tsi_peer_property* alts_context_prop = + tsi_peer_get_property_by_name(peer, TSI_ALTS_CONTEXT); + if (alts_context_prop == nullptr) { + gpr_log(GPR_ERROR, "Missing alts context property."); + return nullptr; + } /* Create auth context. */ auto ctx = grpc_core::MakeRefCounted(nullptr); grpc_auth_context_add_cstring_property( @@ -226,6 +233,12 @@ grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) { GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1); } + /* Add alts context to auth context. */ + if (strcmp(tsi_prop->name, TSI_ALTS_CONTEXT) == 0) { + grpc_auth_context_add_property( + ctx.get(), TSI_ALTS_CONTEXT, + tsi_prop->value.data, tsi_prop->value.length); + } } if (!grpc_auth_context_peer_is_authenticated(ctx.get())) { gpr_log(GPR_ERROR, "Invalid unauthenticated peer."); diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc index c5383b3a0ba..77758a45acd 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc @@ -29,6 +29,7 @@ #include #include #include +#include "src/core/ext/upb-generated/src/proto/grpc/gcp/altscontext.upb.h" #include "src/core/lib/gprpp/thd.h" #include "src/core/lib/iomgr/closure.h" @@ -63,6 +64,7 @@ typedef struct alts_tsi_handshaker_result { size_t unused_bytes_size; grpc_slice rpc_versions; bool is_client; + grpc_slice serialized_context; } alts_tsi_handshaker_result; static tsi_result handshaker_result_extract_peer( @@ -74,7 +76,7 @@ static tsi_result handshaker_result_extract_peer( alts_tsi_handshaker_result* result = reinterpret_cast( const_cast(self)); - GPR_ASSERT(kTsiAltsNumOfPeerProperties == 3); + GPR_ASSERT(kTsiAltsNumOfPeerProperties == 4); tsi_result ok = tsi_construct_peer(kTsiAltsNumOfPeerProperties, peer); int index = 0; if (ok != TSI_OK) { @@ -104,7 +106,16 @@ static tsi_result handshaker_result_extract_peer( ok = tsi_construct_string_peer_property( TSI_ALTS_RPC_VERSIONS, reinterpret_cast(GRPC_SLICE_START_PTR(result->rpc_versions)), - GRPC_SLICE_LENGTH(result->rpc_versions), &peer->properties[2]); + GRPC_SLICE_LENGTH(result->rpc_versions), &peer->properties[index]); + if (ok != TSI_OK) { + tsi_peer_destruct(peer); + gpr_log(GPR_ERROR, "Failed to set tsi peer property"); + } + index++; + GPR_ASSERT(&peer->properties[index] != nullptr); + ok = tsi_construct_string_peer_property( + TSI_ALTS_CONTEXT, reinterpret_cast(GRPC_SLICE_START_PTR(result->serialized_context)), GRPC_SLICE_LENGTH(result->serialized_context), + &peer->properties[index]); if (ok != TSI_OK) { tsi_peer_destruct(peer); gpr_log(GPR_ERROR, "Failed to set tsi peer property"); @@ -223,6 +234,27 @@ tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp, gpr_log(GPR_ERROR, "Peer does not set RPC protocol versions."); return TSI_FAILED_PRECONDITION; } + upb_strview application_protocol = grpc_gcp_HandshakerResult_application_protocol(hresult); + if (application_protocol.size == 0) { + gpr_log(GPR_ERROR, "Invalid application protocol"); + return TSI_FAILED_PRECONDITION; + } + upb_strview record_protocol = grpc_gcp_HandshakerResult_record_protocol(hresult); + if (record_protocol.size == 0) { + gpr_log(GPR_ERROR, "Invalid record protocol"); + return TSI_FAILED_PRECONDITION; + } + const grpc_gcp_Identity* local_identity = + grpc_gcp_HandshakerResult_local_identity(hresult); + if (local_identity == nullptr) { + gpr_log(GPR_ERROR, "Invalid local identity"); + return TSI_FAILED_PRECONDITION; + } + upb_strview local_service_account = grpc_gcp_Identity_service_account(local_identity); + if (local_service_account.size == 0) { + gpr_log(GPR_ERROR, "Invalid local service account"); + return TSI_FAILED_PRECONDITION; + } alts_tsi_handshaker_result* result = static_cast(gpr_zalloc(sizeof(*result))); result->key_data = @@ -231,13 +263,29 @@ tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp, result->peer_identity = static_cast(gpr_zalloc(service_account.size + 1)); memcpy(result->peer_identity, service_account.data, service_account.size); - upb::Arena arena; + upb::Arena rpc_protocol_arena; bool serialized = grpc_gcp_rpc_protocol_versions_encode( - peer_rpc_version, arena.ptr(), &result->rpc_versions); + peer_rpc_version, rpc_protocol_arena.ptr(), &result->rpc_versions); if (!serialized) { gpr_log(GPR_ERROR, "Failed to serialize peer's RPC protocol versions."); return TSI_FAILED_PRECONDITION; } + upb::Arena context_arena; + grpc_gcp_AltsContext* context = grpc_gcp_AltsContext_new(context_arena.ptr()); + grpc_gcp_AltsContext_set_application_protocol(context, application_protocol); + grpc_gcp_AltsContext_set_record_protocol(context, record_protocol); + grpc_gcp_AltsContext_set_security_level(context, 2); + grpc_gcp_AltsContext_set_peer_service_account(context, service_account); + grpc_gcp_AltsContext_set_local_service_account(context, local_service_account); + grpc_gcp_AltsContext_set_peer_rpc_versions(context, const_cast(peer_rpc_version)); + size_t serialized_ctx_length; + char* serialized_ctx = + grpc_gcp_AltsContext_serialize(context, context_arena.ptr(), &serialized_ctx_length); + if (serialized_ctx == nullptr) { + gpr_log(GPR_ERROR, "Failed to serialize peer's ALTS context."); + return TSI_FAILED_PRECONDITION; + } + result->serialized_context = grpc_slice_from_copied_buffer(serialized_ctx, serialized_ctx_length); result->is_client = is_client; result->base.vtable = &result_vtable; *self = &result->base; diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h index 6be45d7b447..5bea01a6854 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h @@ -30,11 +30,12 @@ #include "src/core/tsi/transport_security_interface.h" #include "src/proto/grpc/gcp/handshaker.upb.h" -#define TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY "service_accont" +#define TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY "service_account" #define TSI_ALTS_CERTIFICATE_TYPE "ALTS" #define TSI_ALTS_RPC_VERSIONS "rpc_versions" +#define TSI_ALTS_CONTEXT "alts_context" -const size_t kTsiAltsNumOfPeerProperties = 3; +const size_t kTsiAltsNumOfPeerProperties = 4; typedef struct alts_tsi_handshaker alts_tsi_handshaker; diff --git a/test/core/security/alts_security_connector_test.cc b/test/core/security/alts_security_connector_test.cc index bcba3408216..e32ef970ef9 100644 --- a/test/core/security/alts_security_connector_test.cc +++ b/test/core/security/alts_security_connector_test.cc @@ -129,13 +129,19 @@ static void test_alts_peer_to_auth_context_success() { grpc_slice serialized_peer_versions; GPR_ASSERT(grpc_gcp_rpc_protocol_versions_encode(&peer_versions, &serialized_peer_versions)); - GPR_ASSERT(tsi_construct_string_peer_property( TSI_ALTS_RPC_VERSIONS, reinterpret_cast( GRPC_SLICE_START_PTR(serialized_peer_versions)), GRPC_SLICE_LENGTH(serialized_peer_versions), &peer.properties[2]) == TSI_OK); + grpc_slice serialized_alts_ctx; + GPR_ASSERT(tsi_construct_string_peer_property( + TSI_ALTS_CONTEXT, + reinterpret_cast( + GRPC_SLICE_START_PTR(serialized_alts_ctx)), + GRPC_SLICE_LENGTH(serialized_alts_ctx), + &peer.properties[3]) == TSI_OK); grpc_core::RefCountedPtr ctx = grpc_alts_auth_context_from_tsi_peer(&peer); GPR_ASSERT(ctx != nullptr); @@ -143,6 +149,7 @@ static void test_alts_peer_to_auth_context_success() { "alice")); ctx.reset(DEBUG_LOCATION, "test"); grpc_slice_unref(serialized_peer_versions); + grpc_slice_unref(serialized_alts_ctx); tsi_peer_destruct(&peer); } diff --git a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc index 37731379f22..22e4100c7f6 100644 --- a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc +++ b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc @@ -28,6 +28,7 @@ #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h" #include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" +#include "src/core/ext/upb-generated/src/proto/grpc/gcp/altscontext.upb.h" #define ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES "Hello World" #define ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME "Hello Google" @@ -42,6 +43,9 @@ #define ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR 2 #define ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR 2 #define ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR 1 +#define ALTS_TSI_HANDSHAKER_TEST_LOCAL_IDENTITY "chapilocal@service.google.com" +#define ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL "test application protocol" +#define ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL "test record protocol" using grpc_core::internal::alts_handshaker_client_check_fields_for_testing; using grpc_core::internal::alts_handshaker_client_get_handshaker_for_testing; @@ -117,6 +121,7 @@ static grpc_byte_buffer* generate_handshaker_response( grpc_gcp_HandshakerStatus* status = grpc_gcp_HandshakerResp_mutable_status(resp, arena.ptr()); grpc_gcp_HandshakerStatus_set_code(status, 0); + grpc_gcp_Identity* local_identity; switch (type) { case INVALID: break; @@ -143,6 +148,15 @@ static grpc_byte_buffer* generate_handshaker_response( ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR, ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR, ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR)); + local_identity = + grpc_gcp_HandshakerResult_mutable_local_identity(result, arena.ptr()); + grpc_gcp_Identity_set_service_account( + local_identity, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_LOCAL_IDENTITY)); + grpc_gcp_HandshakerResult_set_application_protocol( + result, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL)); + grpc_gcp_HandshakerResult_set_record_protocol( + result, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL)); break; case SERVER_NEXT: grpc_gcp_HandshakerResp_set_bytes_consumed( @@ -160,6 +174,15 @@ static grpc_byte_buffer* generate_handshaker_response( ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR, ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR, ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR)); + local_identity = + grpc_gcp_HandshakerResult_mutable_local_identity(result, arena.ptr()); + grpc_gcp_Identity_set_service_account( + local_identity, + upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_LOCAL_IDENTITY)); + grpc_gcp_HandshakerResult_set_application_protocol( + result, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL)); + grpc_gcp_HandshakerResult_set_record_protocol( + result, upb_strview_makez(ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL)); break; case FAILED: grpc_gcp_HandshakerStatus_set_code(status, 3 /* INVALID ARGUMENT */); @@ -261,6 +284,27 @@ static void on_client_next_success_cb(tsi_result status, void* user_data, GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY, peer.properties[1].value.data, peer.properties[1].value.length) == 0); + /* Validate alts context. */ + upb::Arena context_arena; + grpc_gcp_AltsContext* ctx = + grpc_gcp_AltsContext_parse(peer.properties[3].value.data, peer.properties[3].value.length, context_arena.ptr()); + GPR_ASSERT(ctx != nullptr); + upb_strview application_protocol = grpc_gcp_AltsContext_application_protocol(ctx); + upb_strview record_protocol = grpc_gcp_AltsContext_record_protocol(ctx); + upb_strview peer_account = grpc_gcp_AltsContext_peer_service_account(ctx); + upb_strview local_account = grpc_gcp_AltsContext_local_service_account(ctx); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_APPLICATION_PROTOCOL, + application_protocol.data, + application_protocol.size) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_RECORD_PROTOCOL, + record_protocol.data, + record_protocol.size) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY, + peer_account.data, + peer_account.size) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_LOCAL_IDENTITY, + local_account.data, + local_account.size) == 0); tsi_peer_destruct(&peer); /* Validate unused bytes. */ const unsigned char* bytes = nullptr;