pull/35537/head
Craig Tiller 1 year ago
parent 9cf228b8a4
commit cc90d93e5f
  1. 7
      src/core/ext/transport/chaotic_good/chaotic_good_transport.h
  2. 3
      src/core/ext/transport/chaotic_good/client_transport.cc
  3. 40
      src/core/ext/transport/chaotic_good/frame.cc
  4. 22
      src/core/ext/transport/chaotic_good/frame.h
  5. 4
      src/core/ext/transport/chaotic_good/frame_header.cc
  6. 2
      src/core/ext/transport/chaotic_good/frame_header.h
  7. 6
      src/core/ext/transport/chaotic_good/server_transport.cc
  8. 14
      test/core/transport/chaotic_good/frame_fuzzer.cc
  9. 4
      test/core/transport/chaotic_good/frame_test.cc

@ -49,7 +49,7 @@ class ChaoticGoodTransport {
// Resolves to StatusOr<tuple<FrameHeader, BufferPair>>.
auto ReadFrameBytes() {
return TrySeq(
control_endpoint_->ReadSlice(FrameHeader::frame_header_size_),
control_endpoint_->ReadSlice(FrameHeader::kFrameHeaderSize),
[this](Slice read_buffer) {
auto frame_header =
FrameHeader::Parse(reinterpret_cast<const uint8_t*>(
@ -88,9 +88,10 @@ class ChaoticGoodTransport {
}
absl::Status DeserializeFrame(FrameHeader header, BufferPair buffers,
Arena* arena, FrameInterface& frame) {
Arena* arena, FrameInterface& frame,
FrameLimits limits) {
return frame.Deserialize(&parser_, header, bitgen_, arena,
std::move(buffers));
std::move(buffers), limits);
}
// Skip a frame, but correctly handle any hpack state updates.

@ -135,7 +135,8 @@ auto ChaoticGoodClientTransport::TransportReadLoop() {
absl::Status deserialize_status;
if (call_handler.has_value()) {
deserialize_status = transport_.DeserializeFrame(
frame_header, std::move(buffers), call_handler->arena(), frame);
frame_header, std::move(buffers), call_handler->arena(), frame,
FrameLimits{1024 * 1024 * 1024, aligned_bytes_ - 1});
} else {
// Stream not found, skip the frame.
transport_.SkipFrame(frame_header, std::move(buffers));

@ -47,8 +47,8 @@ const uint8_t kZeros[64] = {};
namespace {
const NoDestruct<Slice> kZeroSlice{[] {
// Frame header size is fixed to 24 bytes.
auto slice = GRPC_SLICE_MALLOC(FrameHeader::frame_header_size_);
memset(GRPC_SLICE_START_PTR(slice), 0, FrameHeader::frame_header_size_);
auto slice = GRPC_SLICE_MALLOC(FrameHeader::kFrameHeaderSize);
memset(GRPC_SLICE_START_PTR(slice), 0, FrameHeader::kFrameHeaderSize);
return slice;
}()};
@ -81,7 +81,7 @@ class FrameSerializer {
SliceBuffer& AddTrailers() {
header_.flags.set(2);
header_.header_length =
output_.control.Length() - FrameHeader::frame_header_size_;
output_.control.Length() - FrameHeader::kFrameHeaderSize;
return output_.control;
}
@ -91,13 +91,13 @@ class FrameSerializer {
// Header length is already known in AddTrailers().
header_.trailer_length = output_.control.Length() -
header_.header_length -
FrameHeader::frame_header_size_;
FrameHeader::kFrameHeaderSize;
} else {
if (header_.flags.is_set(0)) {
// Calculate frame header length in Finish() since AddTrailers() isn't
// called.
header_.header_length =
output_.control.Length() - FrameHeader::frame_header_size_;
output_.control.Length() - FrameHeader::kFrameHeaderSize;
}
}
header_.Serialize(
@ -175,9 +175,23 @@ absl::StatusOr<Arena::PoolPtr<Metadata>> ReadMetadata(
}
} // namespace
absl::Status FrameLimits::ValidateMessage(const FrameHeader& header) {
if (header.message_length > max_message_size) {
return absl::InvalidArgumentError(
absl::StrCat("Message length ", header.message_length,
" exceeds maximum allowed ", max_message_size));
}
if (header.message_padding > max_padding) {
return absl::InvalidArgumentError(
absl::StrCat("Message padding ", header.message_padding,
" exceeds maximum allowed ", max_padding));
}
return absl::OkStatus();
}
absl::Status SettingsFrame::Deserialize(HPackParser*, const FrameHeader& header,
absl::BitGenRef, Arena*,
BufferPair buffers) {
BufferPair buffers, FrameLimits) {
if (header.type != FrameType::kSettings) {
return absl::InvalidArgumentError("Expected settings frame");
}
@ -201,8 +215,8 @@ std::string SettingsFrame::ToString() const { return "SettingsFrame{}"; }
absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser,
const FrameHeader& header,
absl::BitGenRef bitsrc,
Arena* arena,
BufferPair buffers) {
Arena* arena, BufferPair buffers,
FrameLimits limits) {
if (header.stream_id == 0) {
return absl::InvalidArgumentError("Expected non-zero stream id");
}
@ -224,6 +238,8 @@ absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser,
"Unexpected non-zero header length", header.header_length));
}
if (header.flags.is_set(1)) {
auto r = limits.ValidateMessage(header);
if (!r.ok()) return r;
message =
FragmentMessage{Arena::MakePooled<Message>(std::move(buffers.data), 0),
header.message_padding, header.message_length};
@ -279,8 +295,8 @@ std::string ClientFragmentFrame::ToString() const {
absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser,
const FrameHeader& header,
absl::BitGenRef bitsrc,
Arena* arena,
BufferPair buffers) {
Arena* arena, BufferPair buffers,
FrameLimits limits) {
if (header.stream_id == 0) {
return absl::InvalidArgumentError("Expected non-zero stream id");
}
@ -299,6 +315,8 @@ absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser,
"Unexpected non-zero header length", header.header_length));
}
if (header.flags.is_set(1)) {
auto r = limits.ValidateMessage(header);
if (!r.ok()) return r;
message.emplace(Arena::MakePooled<Message>(std::move(buffers.data), 0),
header.message_padding, header.message_length);
} else if (buffers.data.Length() != 0) {
@ -347,7 +365,7 @@ std::string ServerFragmentFrame::ToString() const {
absl::Status CancelFrame::Deserialize(HPackParser*, const FrameHeader& header,
absl::BitGenRef, Arena*,
BufferPair buffers) {
BufferPair buffers, FrameLimits) {
if (header.type != FrameType::kCancel) {
return absl::InvalidArgumentError("Expected cancel frame");
}

@ -42,12 +42,19 @@ struct BufferPair {
SliceBuffer data;
};
struct FrameLimits {
size_t max_message_size = 1024 * 1024 * 1024;
size_t max_padding = 63;
absl::Status ValidateMessage(const FrameHeader& header);
};
class FrameInterface {
public:
virtual absl::Status Deserialize(HPackParser* parser,
const FrameHeader& header,
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) = 0;
BufferPair buffers, FrameLimits limits) = 0;
virtual BufferPair Serialize(HPackCompressor* encoder) const = 0;
virtual std::string ToString() const = 0;
@ -68,7 +75,7 @@ class FrameInterface {
struct SettingsFrame final : public FrameInterface {
absl::Status Deserialize(HPackParser* parser, const FrameHeader& header,
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) override;
BufferPair buffers, FrameLimits limits) override;
BufferPair Serialize(HPackCompressor* encoder) const override;
std::string ToString() const override;
@ -91,14 +98,17 @@ struct FragmentMessage {
}
bool operator==(const FragmentMessage& other) const {
return EqVal(*message, *other.message) && length == other.length;
if (length != other.length) return false;
if (message == nullptr && other.message == nullptr) return true;
if (message == nullptr || other.message == nullptr) return false;
return EqVal(*message, *other.message);
}
};
struct ClientFragmentFrame final : public FrameInterface {
absl::Status Deserialize(HPackParser* parser, const FrameHeader& header,
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) override;
BufferPair buffers, FrameLimits limits) override;
BufferPair Serialize(HPackCompressor* encoder) const override;
std::string ToString() const override;
@ -116,7 +126,7 @@ struct ClientFragmentFrame final : public FrameInterface {
struct ServerFragmentFrame final : public FrameInterface {
absl::Status Deserialize(HPackParser* parser, const FrameHeader& header,
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) override;
BufferPair buffers, FrameLimits limits) override;
BufferPair Serialize(HPackCompressor* encoder) const override;
std::string ToString() const override;
@ -134,7 +144,7 @@ struct ServerFragmentFrame final : public FrameInterface {
struct CancelFrame final : public FrameInterface {
absl::Status Deserialize(HPackParser* parser, const FrameHeader& header,
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) override;
BufferPair buffers, FrameLimits limits) override;
BufferPair Serialize(HPackCompressor* encoder) const override;
std::string ToString() const override;

@ -70,10 +70,6 @@ absl::StatusOr<FrameHeader> FrameHeader::Parse(const uint8_t* data) {
}
header.message_length = ReadLittleEndianUint32(data + 12);
header.message_padding = ReadLittleEndianUint32(data + 16);
if (header.flags.is_set(1) && header.message_length <= 0) {
return absl::InvalidArgumentError(
absl::StrCat("Invalid message length: ", header.message_length));
}
header.trailer_length = ReadLittleEndianUint32(data + 20);
return header;
}

@ -60,7 +60,7 @@ struct FrameHeader {
trailer_length == h.trailer_length;
}
// Frame header size is fixed to 24 bytes.
static constexpr size_t frame_header_size_ = 24;
static constexpr size_t kFrameHeaderSize = 24;
};
} // namespace chaotic_good

@ -165,7 +165,8 @@ auto ChaoticGoodServerTransport::DeserializeAndPushFragmentToNewCall(
ClientFragmentFrame fragment_frame;
ScopedArenaPtr arena(acceptor_->CreateArena());
absl::Status status = transport_.DeserializeFrame(
frame_header, std::move(buffers), arena.get(), fragment_frame);
frame_header, std::move(buffers), arena.get(), fragment_frame,
FrameLimits{1024 * 1024 * 1024, aligned_bytes_ - 1});
absl::optional<CallInitiator> call_initiator;
if (status.ok()) {
auto create_call_result =
@ -193,7 +194,8 @@ auto ChaoticGoodServerTransport::DeserializeAndPushFragmentToExistingCall(
if (call_initiator.has_value()) arena = call_initiator->arena();
ClientFragmentFrame fragment_frame;
absl::Status status = transport_.DeserializeFrame(
frame_header, std::move(buffers), arena, fragment_frame);
frame_header, std::move(buffers), arena, fragment_frame,
FrameLimits{1024 * 1024 * 1024, aligned_bytes_ - 1});
return MaybePushFragmentIntoCall(std::move(call_initiator), std::move(status),
std::move(fragment_frame));
}

@ -49,6 +49,8 @@ struct DeterministicBitGen : public std::numeric_limits<uint64_t> {
uint64_t operator()() { return 42; }
};
FrameLimits FuzzerFrameLimits() { return FrameLimits{1024 * 1024 * 1024, 63}; }
template <typename T>
void AssertRoundTrips(const T& input, FrameType expected_frame_type) {
HPackCompressor hpack_compressor;
@ -69,9 +71,9 @@ void AssertRoundTrips(const T& input, FrameType expected_frame_type) {
T output;
HPackParser hpack_parser;
DeterministicBitGen bitgen;
auto deser =
output.Deserialize(&hpack_parser, header.value(), absl::BitGenRef(bitgen),
GetContext<Arena>(), std::move(serialized));
auto deser = output.Deserialize(&hpack_parser, header.value(),
absl::BitGenRef(bitgen), GetContext<Arena>(),
std::move(serialized), FuzzerFrameLimits());
GPR_ASSERT(deser.ok());
GPR_ASSERT(output == input);
}
@ -82,9 +84,9 @@ void FinishParseAndChecks(const FrameHeader& header, BufferPair buffers) {
ExecCtx exec_ctx; // Initialized to get this_cpu() info in global_stat().
HPackParser hpack_parser;
DeterministicBitGen bitgen;
auto deser =
parsed.Deserialize(&hpack_parser, header, absl::BitGenRef(bitgen),
GetContext<Arena>(), std::move(buffers));
auto deser = parsed.Deserialize(&hpack_parser, header,
absl::BitGenRef(bitgen), GetContext<Arena>(),
std::move(buffers), FuzzerFrameLimits());
if (!deser.ok()) return;
gpr_log(GPR_INFO, "Read frame: %s", parsed.ToString().c_str());
AssertRoundTrips(parsed, header.type);

@ -28,6 +28,8 @@ namespace grpc_core {
namespace chaotic_good {
namespace {
FrameLimits TestFrameLimits() { return FrameLimits{1024 * 1024 * 1024, 63}; }
template <typename T>
void AssertRoundTrips(const T& input, FrameType expected_frame_type) {
HPackCompressor hpack_compressor;
@ -50,7 +52,7 @@ void AssertRoundTrips(const T& input, FrameType expected_frame_type) {
ScopedArenaPtr arena = MakeScopedArena(1024, &allocator);
auto deser =
output.Deserialize(&hpack_parser, header.value(), absl::BitGenRef(bitgen),
arena.get(), std::move(serialized));
arena.get(), std::move(serialized), TestFrameLimits());
GPR_ASSERT(deser.ok());
GPR_ASSERT(output == input);
}

Loading…
Cancel
Save