@ -24,6 +24,7 @@
# include <type_traits>
# include "absl/log/check.h"
# include "src/core/lib/promise/for_each.h"
# include "src/core/lib/promise/if.h"
# include "src/core/lib/promise/latch.h"
# include "src/core/lib/promise/map.h"
@ -121,6 +122,104 @@ struct NoInterceptor {};
namespace filters_detail {
// Flow control across pipe stages.
// This ends up being exceedingly subtle - essentially we need to ensure that
// across a series of pipes we have no more than one outstanding message at a
// time - but those pipes are for the most part independent.
// How we achieve this is that this NextMessage object holds both the message
// and a completion token - the last owning NextMessage instance will call
// the on_progress method on the referenced CallState - and at that point that
// CallState will allow the next message to be sent through it.
// Next, the ForEach promise combiner explicitly holds onto the wrapper object
// owning the result (this object) and extracts the message from it, but doesn't
// dispose that instance until the action promise for the ForEach iteration
// completes, ensuring most callers need do nothing special to have the
// flow control work correctly.
template < void ( CallState : : * on_progress ) ( ) >
class NextMessage {
public :
~ NextMessage ( ) {
if ( message_ ! = end_of_stream ( ) & & message_ ! = error ( ) & &
message_ ! = taken ( ) ) {
delete message_ ;
}
if ( call_state_ ! = nullptr ) {
( call_state_ - > * on_progress ) ( ) ;
}
}
NextMessage ( ) = default ;
explicit NextMessage ( Failure ) : message_ ( error ( ) ) , call_state_ ( nullptr ) { }
NextMessage ( MessageHandle message , CallState * call_state ) {
DCHECK_NE ( call_state , nullptr ) ;
DCHECK_NE ( message . get ( ) , nullptr ) ;
DCHECK ( message . get_deleter ( ) . has_freelist ( ) ) ;
message_ = message . release ( ) ;
call_state_ = call_state ;
}
NextMessage ( const NextMessage & other ) = delete ;
NextMessage & operator = ( const NextMessage & other ) = delete ;
NextMessage ( NextMessage & & other ) noexcept
: message_ ( std : : exchange ( other . message_ , taken ( ) ) ) ,
call_state_ ( std : : exchange ( other . call_state_ , nullptr ) ) { }
NextMessage & operator = ( NextMessage & & other ) noexcept {
if ( message_ ! = end_of_stream ( ) & & message_ ! = error ( ) & &
message_ ! = taken ( ) ) {
delete message_ ;
}
if ( call_state_ ! = nullptr ) {
( call_state_ - > * on_progress ) ( ) ;
}
message_ = std : : exchange ( other . message_ , taken ( ) ) ;
call_state_ = std : : exchange ( other . call_state_ , nullptr ) ;
return * this ;
}
bool ok ( ) const {
DCHECK_NE ( message_ , taken ( ) ) ;
return message_ ! = error ( ) ;
}
bool has_value ( ) const {
DCHECK_NE ( message_ , taken ( ) ) ;
DCHECK ( ok ( ) ) ;
return message_ ! = end_of_stream ( ) ;
}
StatusFlag status ( ) const { return StatusFlag ( ok ( ) ) ; }
Message & value ( ) {
DCHECK_NE ( message_ , taken ( ) ) ;
DCHECK ( ok ( ) ) ;
DCHECK ( has_value ( ) ) ;
return * message_ ;
}
MessageHandle TakeValue ( ) {
DCHECK_NE ( message_ , taken ( ) ) ;
DCHECK ( ok ( ) ) ;
DCHECK ( has_value ( ) ) ;
return MessageHandle ( std : : exchange ( message_ , taken ( ) ) ,
Arena : : PooledDeleter ( ) ) ;
}
bool progressed ( ) const { return call_state_ = = nullptr ; }
void Progress ( ) {
DCHECK ( ! progressed ( ) ) ;
( call_state_ - > * on_progress ) ( ) ;
call_state_ = nullptr ;
}
private :
static Message * end_of_stream ( ) { return nullptr ; }
static Message * error ( ) { return reinterpret_cast < Message * > ( 1 ) ; }
static Message * taken ( ) { return reinterpret_cast < Message * > ( 2 ) ; }
Message * message_ = end_of_stream ( ) ;
CallState * call_state_ = nullptr ;
} ;
template < typename T >
struct ArgumentMustBeNextMessage ;
template < void ( CallState : : * on_progress ) ( ) >
struct ArgumentMustBeNextMessage < NextMessage < on_progress > > {
static constexpr bool value ( ) { return true ; }
} ;
inline void * Offset ( void * base , size_t amt ) {
return static_cast < char * > ( base ) + amt ;
}
@ -1301,6 +1400,80 @@ const NoInterceptor ClientInitialMetadataInterceptor<Fn>::Call::OnFinalize;
} // namespace filters_detail
namespace for_each_detail {
template < void ( CallState : : * on_progress ) ( ) >
struct NextValueTraits < filters_detail : : NextMessage < on_progress > > {
using NextMsg = filters_detail : : NextMessage < on_progress > ;
using Value = MessageHandle ;
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION static NextValueType Type (
const NextMsg & t ) {
if ( ! t . ok ( ) ) return NextValueType : : kError ;
if ( t . has_value ( ) ) return NextValueType : : kValue ;
return NextValueType : : kEndOfStream ;
}
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION static MessageHandle TakeValue (
NextMsg & t ) {
return t . TakeValue ( ) ;
}
} ;
} // namespace for_each_detail
template < void ( CallState : : * on_progress ) ( ) >
struct FailureStatusCastImpl < filters_detail : : NextMessage < on_progress > ,
StatusFlag > {
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION static filters_detail : : NextMessage <
on_progress >
Cast ( StatusFlag flag ) {
DCHECK_EQ ( flag , Failure { } ) ;
return filters_detail : : NextMessage < on_progress > ( Failure { } ) ;
}
} ;
namespace promise_detail {
template < void ( CallState : : * on_progress ) ( ) >
struct TrySeqTraitsWithSfinae < filters_detail : : NextMessage < on_progress > > {
using UnwrappedType = MessageHandle ;
using WrappedType = filters_detail : : NextMessage < on_progress > ;
template < typename Next >
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION static auto CallFactory (
Next * next , WrappedType & & value ) {
return next - > Make ( value . TakeValue ( ) ) ;
}
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION static bool IsOk (
const WrappedType & value ) {
return value . ok ( ) ;
}
static const char * ErrorString ( const WrappedType & status ) {
DCHECK ( ! status . ok ( ) ) ;
return " failed " ;
}
template < typename R >
static R ReturnValue ( WrappedType & & status ) {
DCHECK ( ! status . ok ( ) ) ;
return WrappedType ( Failure { } ) ;
}
template < typename F , typename Elem >
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION static auto CallSeqFactory (
F & f , Elem & & elem , WrappedType value )
- > decltype ( f ( std : : forward < Elem > ( elem ) , std : : declval < MessageHandle > ( ) ) ) {
return f ( std : : forward < Elem > ( elem ) , value . TakeValue ( ) ) ;
}
template < typename Result , typename RunNext >
GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION static Poll < Result >
CheckResultAndRunNext ( WrappedType prior , RunNext run_next ) {
if ( ! prior . ok ( ) ) return WrappedType ( prior . status ( ) ) ;
return run_next ( std : : move ( prior ) ) ;
}
} ;
} // namespace promise_detail
using ServerToClientNextMessage =
filters_detail : : NextMessage < & CallState : : FinishPullServerToClientMessage > ;
using ClientToServerNextMessage =
filters_detail : : NextMessage < & CallState : : FinishPullClientToServerMessage > ;
// Execution environment for a stack of filters.
// This is a per-call object.
class CallFilters {
@ -1415,10 +1588,10 @@ class CallFilters {
Input ( CallFilters : : * input_location ) ,
filters_detail : : Layout < Input > ( filters_detail : : StackData : : * layout ) ,
void ( CallState : : * on_done ) ( ) , typename StackIterator >
class Executor {
class Metadata Executor {
public :
Executor ( CallFilters * filters , StackIterator stack_begin ,
StackIterator stack_end )
Metadata Executor( CallFilters * filters , StackIterator stack_begin ,
StackIterator stack_end )
: stack_current_ ( stack_begin ) ,
stack_end_ ( stack_end ) ,
filters_ ( filters ) {
@ -1466,17 +1639,72 @@ class CallFilters {
filters_detail : : OperationExecutor < Input > executor_ ;
} ;
template < MessageHandle ( CallFilters : : * input_location ) ,
filters_detail : : Layout < MessageHandle > (
filters_detail : : StackData : : * layout ) ,
void ( CallState : : * on_done ) ( ) , typename StackIterator >
class MessageExecutor {
public :
using NextMsg = filters_detail : : NextMessage < on_done > ;
MessageExecutor ( CallFilters * filters , StackIterator stack_begin ,
StackIterator stack_end )
: stack_current_ ( stack_begin ) ,
stack_end_ ( stack_end ) ,
filters_ ( filters ) {
DCHECK_NE ( ( filters_ - > * input_location ) . get ( ) , nullptr ) ;
}
Poll < NextMsg > operator ( ) ( ) {
if ( ( filters_ - > * input_location ) ! = nullptr ) {
if ( stack_current_ = = stack_end_ ) {
DCHECK_NE ( ( filters_ - > * input_location ) . get ( ) , nullptr ) ;
return NextMsg ( std : : move ( filters_ - > * input_location ) ,
& filters_ - > call_state_ ) ;
}
return FinishStep ( executor_ . Start (
& ( stack_current_ - > stack - > data_ . * layout ) ,
std : : move ( filters_ - > * input_location ) , filters_ - > call_data_ ) ) ;
} else {
return FinishStep ( executor_ . Step ( filters_ - > call_data_ ) ) ;
}
}
private :
Poll < NextMsg > FinishStep ( Poll < filters_detail : : ResultOr < MessageHandle > > p ) {
auto * r = p . value_if_ready ( ) ;
if ( r = = nullptr ) return Pending { } ;
if ( r - > ok ! = nullptr ) {
+ + stack_current_ ;
if ( stack_current_ = = stack_end_ ) {
return NextMsg { std : : move ( r - > ok ) , & filters_ - > call_state_ } ;
}
return FinishStep (
executor_ . Start ( & ( stack_current_ - > stack - > data_ . * layout ) ,
std : : move ( r - > ok ) , filters_ - > call_data_ ) ) ;
}
( filters_ - > call_state_ . * on_done ) ( ) ;
filters_ - > PushServerTrailingMetadata ( std : : move ( r - > error ) ) ;
return Failure { } ;
}
StackIterator stack_current_ ;
StackIterator stack_end_ ;
CallFilters * filters_ ;
filters_detail : : OperationExecutor < MessageHandle > executor_ ;
} ;
public :
// Client: Fetch client initial metadata
// Returns a promise that resolves to ValueOrFailure<ClientMetadataHandle>
GRPC_MUST_USE_RESULT auto PullClientInitialMetadata ( ) {
call_state_ . BeginPullClientInitialMetadata ( ) ;
return Executor < ClientMetadataHandle , ClientMetadataHandle ,
& CallFilters : : push_client_initial_metadata_ ,
& filters_detail : : StackData : : client_initial_metadata ,
& CallState : : FinishPullClientInitialMetadata ,
StacksVector : : const_iterator > ( this , stacks_ . cbegin ( ) ,
stacks_ . cend ( ) ) ;
return Metadata Executor< ClientMetadataHandle , ClientMetadataHandle ,
& CallFilters : : push_client_initial_metadata_ ,
& filters_detail : : StackData : : client_initial_metadata ,
& CallState : : FinishPullClientInitialMetadata ,
StacksVector : : const_iterator > (
this , stacks_ . cbegin ( ) , stacks_ . cend ( ) ) ;
}
// Server: Push server initial metadata
// Returns a promise that resolves to a StatusFlag indicating success
@ -1496,7 +1724,7 @@ class CallFilters {
has_server_initial_metadata ,
[ this ] ( ) {
return Map (
Executor <
Metadata Executor<
absl : : optional < ServerMetadataHandle > ,
ServerMetadataHandle ,
& CallFilters : : push_server_initial_metadata_ ,
@ -1526,7 +1754,7 @@ class CallFilters {
// Client: Indicate that no more messages will be sent
void FinishClientToServerSends ( ) { call_state_ . ClientToServerHalfClose ( ) ; }
// Server: Fetch client to server message
// Returns a promise that resolves to ValueOrFailure<MessageHandle>
// Returns a promise that resolves to ClientToServerNextMessage
GRPC_MUST_USE_RESULT auto PullClientToServerMessage ( ) {
return TrySeq (
[ this ] ( ) {
@ -1536,16 +1764,15 @@ class CallFilters {
return If (
message_available ,
[ this ] ( ) {
return Executor <
absl : : optional < MessageHandle > , MessageHandle ,
return MessageExecutor <
& CallFilters : : push_client_to_server_message_ ,
& filters_detail : : StackData : : client_to_server_messages ,
& CallState : : FinishPullClientToServerMessage ,
StacksVector : : const_iterator > ( this , stacks_ . cbegin ( ) ,
stacks_ . cend ( ) ) ;
} ,
[ ] ( ) - > ValueOrFailure < absl : : optional < MessageHandle > > {
return absl : : optional < MessageHandle > ( ) ;
[ ] ( ) - > ClientToServerNextMessage {
return ClientToServerNextMessage ( ) ;
} ) ;
} ) ;
}
@ -1557,7 +1784,7 @@ class CallFilters {
return [ this ] ( ) { return call_state_ . PollPushServerToClientMessage ( ) ; } ;
}
// Server: Fetch server to client message
// Returns a promise that resolves to ValueOrFailure<MessageHandle>
// Returns a promise that resolves to ServerToClientNextMessage
GRPC_MUST_USE_RESULT auto PullServerToClientMessage ( ) {
return TrySeq (
[ this ] ( ) {
@ -1567,16 +1794,15 @@ class CallFilters {
return If (
message_available ,
[ this ] ( ) {
return Executor <
absl : : optional < MessageHandle > , MessageHandle ,
return MessageExecutor <
& CallFilters : : push_server_to_client_message_ ,
& filters_detail : : StackData : : server_to_client_messages ,
& CallState : : FinishPullServerToClientMessage ,
StacksVector : : const_reverse_iterator > (
this , stacks_ . crbegin ( ) , stacks_ . crend ( ) ) ;
} ,
[ ] ( ) - > ValueOrFailure < absl : : optional < MessageHandle > > {
return absl : : optional < MessageHandle > ( ) ;
[ ] ( ) - > ServerToClientNextMessage {
return ServerToClientNextMessage ( ) ;
} ) ;
} ) ;
}
@ -1654,6 +1880,20 @@ class CallFilters {
static char g_empty_call_data_ ;
} ;
static_assert (
filters_detail : : ArgumentMustBeNextMessage <
absl : : remove_cvref_t < decltype ( std : : declval < CallFilters * > ( )
- > PullServerToClientMessage ( ) ( )
. value ( ) ) > > : : value ( ) ,
" PullServerToClientMessage must return a NextMessage " ) ;
static_assert (
filters_detail : : ArgumentMustBeNextMessage <
absl : : remove_cvref_t < decltype ( std : : declval < CallFilters * > ( )
- > PullClientToServerMessage ( ) ( )
. value ( ) ) > > : : value ( ) ,
" PullServerToClientMessage must return a NextMessage " ) ;
} // namespace grpc_core
# endif // GRPC_SRC_CORE_LIB_TRANSPORT_CALL_FILTERS_H