[Ruby] Message.decode/encode: Add max_recursion_depth option (#9218)

* Message.decode/encode: Add max_recursion_depth option

This allows increasing the recursion depth limit from the default of 64 by
setting the "max_recursion_depth" option to the desired integer value. This is
useful for encoding or decoding deeply nested protobuf messages that would
otherwise fail with a RuntimeError or "Error occurred during parsing".

Fixes #1493

* Address review comments

Co-authored-by: Adam Cozzette <acozzette@google.com>
pull/9486/head
Lukas Fittl 3 years ago committed by GitHub
parent 4ed3941e27
commit fbe6ab2487
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 65
      ruby/ext/google/protobuf_c/message.c
  2. 8
      ruby/lib/google/protobuf.rb
  3. 4
      ruby/lib/google/protobuf/message_exts.rb
  4. 4
      ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java
  5. 78
      ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
  6. 51
      ruby/tests/encode_decode_test.rb

@ -953,13 +953,35 @@ static VALUE Message_index_set(VALUE _self, VALUE field_name, VALUE value) {
/* /*
* call-seq: * call-seq:
* MessageClass.decode(data) => message * MessageClass.decode(data, options) => message
* *
* Decodes the given data (as a string containing bytes in protocol buffers wire * Decodes the given data (as a string containing bytes in protocol buffers wire
* format) under the interpretation given by this message class's definition * format) under the interpretation given by this message class's definition
* and returns a message object with the corresponding field values. * and returns a message object with the corresponding field values.
* @param options [Hash] options for the decoder
* max_recursion_depth: set to maximum decoding depth for message (default is 64)
*/ */
static VALUE Message_decode(VALUE klass, VALUE data) { static VALUE Message_decode(int argc, VALUE* argv, VALUE klass) {
VALUE data = argv[0];
int options = 0;
if (argc < 1 || argc > 2) {
rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
}
if (argc == 2) {
VALUE hash_args = argv[1];
if (TYPE(hash_args) != T_HASH) {
rb_raise(rb_eArgError, "Expected hash arguments.");
}
VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));
if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
options |= UPB_DECODE_MAXDEPTH(FIX2INT(depth));
}
}
if (TYPE(data) != T_STRING) { if (TYPE(data) != T_STRING) {
rb_raise(rb_eArgError, "Expected string for binary protobuf data."); rb_raise(rb_eArgError, "Expected string for binary protobuf data.");
} }
@ -969,7 +991,7 @@ static VALUE Message_decode(VALUE klass, VALUE data) {
upb_DecodeStatus status = upb_Decode( upb_DecodeStatus status = upb_Decode(
RSTRING_PTR(data), RSTRING_LEN(data), (upb_Message*)msg->msg, RSTRING_PTR(data), RSTRING_LEN(data), (upb_Message*)msg->msg,
upb_MessageDef_MiniTable(msg->msgdef), NULL, 0, Arena_get(msg->arena)); upb_MessageDef_MiniTable(msg->msgdef), NULL, options, Arena_get(msg->arena));
if (status != kUpb_DecodeStatus_Ok) { if (status != kUpb_DecodeStatus_Ok) {
rb_raise(cParseError, "Error occurred during parsing"); rb_raise(cParseError, "Error occurred during parsing");
@ -1043,24 +1065,43 @@ static VALUE Message_decode_json(int argc, VALUE* argv, VALUE klass) {
/* /*
* call-seq: * call-seq:
* MessageClass.encode(msg) => bytes * MessageClass.encode(msg, options) => bytes
* *
* Encodes the given message object to its serialized form in protocol buffers * Encodes the given message object to its serialized form in protocol buffers
* wire format. * wire format.
* @param options [Hash] options for the encoder
* max_recursion_depth: set to maximum encoding depth for message (default is 64)
*/ */
static VALUE Message_encode(VALUE klass, VALUE msg_rb) { static VALUE Message_encode(int argc, VALUE* argv, VALUE klass) {
Message* msg = ruby_to_Message(msg_rb); Message* msg = ruby_to_Message(argv[0]);
int options = 0;
const char* data; const char* data;
size_t size; size_t size;
if (CLASS_OF(msg_rb) != klass) { if (CLASS_OF(argv[0]) != klass) {
rb_raise(rb_eArgError, "Message of wrong type."); rb_raise(rb_eArgError, "Message of wrong type.");
} }
upb_Arena* arena = upb_Arena_New(); if (argc < 1 || argc > 2) {
rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
}
data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef), 0, arena, if (argc == 2) {
&size); VALUE hash_args = argv[1];
if (TYPE(hash_args) != T_HASH) {
rb_raise(rb_eArgError, "Expected hash arguments.");
}
VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));
if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
options |= UPB_DECODE_MAXDEPTH(FIX2INT(depth));
}
}
upb_Arena *arena = upb_Arena_New();
data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef),
options, arena, &size);
if (data) { if (data) {
VALUE ret = rb_str_new(data, size); VALUE ret = rb_str_new(data, size);
@ -1186,8 +1227,8 @@ VALUE build_class_from_descriptor(VALUE descriptor) {
rb_define_method(klass, "to_s", Message_inspect, 0); rb_define_method(klass, "to_s", Message_inspect, 0);
rb_define_method(klass, "[]", Message_index, 1); rb_define_method(klass, "[]", Message_index, 1);
rb_define_method(klass, "[]=", Message_index_set, 2); rb_define_method(klass, "[]=", Message_index_set, 2);
rb_define_singleton_method(klass, "decode", Message_decode, 1); rb_define_singleton_method(klass, "decode", Message_decode, -1);
rb_define_singleton_method(klass, "encode", Message_encode, 1); rb_define_singleton_method(klass, "encode", Message_encode, -1);
rb_define_singleton_method(klass, "decode_json", Message_decode_json, -1); rb_define_singleton_method(klass, "decode_json", Message_decode_json, -1);
rb_define_singleton_method(klass, "encode_json", Message_encode_json, -1); rb_define_singleton_method(klass, "encode_json", Message_encode_json, -1);
rb_define_singleton_method(klass, "descriptor", Message_descriptor, 0); rb_define_singleton_method(klass, "descriptor", Message_descriptor, 0);

@ -59,16 +59,16 @@ require 'google/protobuf/repeated_field'
module Google module Google
module Protobuf module Protobuf
def self.encode(msg) def self.encode(msg, options = {})
msg.to_proto msg.to_proto(options)
end end
def self.encode_json(msg, options = {}) def self.encode_json(msg, options = {})
msg.to_json(options) msg.to_json(options)
end end
def self.decode(klass, proto) def self.decode(klass, proto, options = {})
klass.decode(proto) klass.decode(proto, options)
end end
def self.decode_json(klass, json, options = {}) def self.decode_json(klass, json, options = {})

@ -44,8 +44,8 @@ module Google
self.class.encode_json(self, options) self.class.encode_json(self, options)
end end
def to_proto def to_proto(options = {})
self.class.encode(self) self.class.encode(self, options)
end end
end end

@ -389,7 +389,7 @@ public class RubyMap extends RubyObject {
return newMap; return newMap;
} }
protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descriptor, int depth) { protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descriptor, int depth, int maxRecursionDepth) {
List<DynamicMessage> list = new ArrayList<DynamicMessage>(); List<DynamicMessage> list = new ArrayList<DynamicMessage>();
RubyClass rubyClass = (RubyClass) descriptor.msgclass(context); RubyClass rubyClass = (RubyClass) descriptor.msgclass(context);
FieldDescriptor keyField = descriptor.getField("key"); FieldDescriptor keyField = descriptor.getField("key");
@ -398,7 +398,7 @@ public class RubyMap extends RubyObject {
RubyMessage mapMessage = (RubyMessage) rubyClass.newInstance(context, Block.NULL_BLOCK); RubyMessage mapMessage = (RubyMessage) rubyClass.newInstance(context, Block.NULL_BLOCK);
mapMessage.setField(context, keyField, key); mapMessage.setField(context, keyField, key);
mapMessage.setField(context, valueField, table.get(key)); mapMessage.setField(context, valueField, table.get(key));
list.add(mapMessage.build(context, depth + 1)); list.add(mapMessage.build(context, depth + 1, maxRecursionDepth));
} }
return list; return list;
} }

@ -39,6 +39,7 @@ import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.OneofDescriptor; import com.google.protobuf.Descriptors.OneofDescriptor;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.DynamicMessage; import com.google.protobuf.DynamicMessage;
import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message; import com.google.protobuf.Message;
@ -461,35 +462,63 @@ public class RubyMessage extends RubyObject {
/* /*
* call-seq: * call-seq:
* MessageClass.encode(msg) => bytes * MessageClass.encode(msg, options = {}) => bytes
* *
* Encodes the given message object to its serialized form in protocol buffers * Encodes the given message object to its serialized form in protocol buffers
* wire format. * wire format.
* @param options [Hash] options for the encoder
* max_recursion_depth: set to maximum encoding depth for message (default is 64)
*/ */
@JRubyMethod(meta = true) @JRubyMethod(required = 1, optional = 1, meta = true)
public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject value) { public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject[] args) {
if (recv != value.getMetaClass()) { if (recv != args[0].getMetaClass()) {
throw context.runtime.newArgumentError("Tried to encode a " + value.getMetaClass() + " message with " + recv); throw context.runtime.newArgumentError("Tried to encode a " + args[0].getMetaClass() + " message with " + recv);
} }
RubyMessage message = (RubyMessage) value; RubyMessage message = (RubyMessage) args[0];
return context.runtime.newString(new ByteList(message.build(context).toByteArray())); int maxRecursionDepthInt = SINK_MAXIMUM_NESTING;
if (args.length > 1) {
RubyHash options = (RubyHash) args[1];
IRubyObject maxRecursionDepth = options.fastARef(context.runtime.newSymbol("max_recursion_depth"));
if (maxRecursionDepth != null) {
maxRecursionDepthInt = ((RubyNumeric) maxRecursionDepth).getIntValue();
}
}
return context.runtime.newString(new ByteList(message.build(context, 0, maxRecursionDepthInt).toByteArray()));
} }
/* /*
* call-seq: * call-seq:
* MessageClass.decode(data) => message * MessageClass.decode(data, options = {}) => message
* *
* Decodes the given data (as a string containing bytes in protocol buffers wire * Decodes the given data (as a string containing bytes in protocol buffers wire
* format) under the interpretation given by this message class's definition * format) under the interpretation given by this message class's definition
* and returns a message object with the corresponding field values. * and returns a message object with the corresponding field values.
* @param options [Hash] options for the decoder
* max_recursion_depth: set to maximum decoding depth for message (default is 100)
*/ */
@JRubyMethod(meta = true) @JRubyMethod(required = 1, optional = 1, meta = true)
public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject data) { public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject[] args) {
IRubyObject data = args[0];
byte[] bin = data.convertToString().getBytes(); byte[] bin = data.convertToString().getBytes();
CodedInputStream input = CodedInputStream.newInstance(bin);
RubyMessage ret = (RubyMessage) ((RubyClass) recv).newInstance(context, Block.NULL_BLOCK); RubyMessage ret = (RubyMessage) ((RubyClass) recv).newInstance(context, Block.NULL_BLOCK);
if (args.length == 2) {
if (!(args[1] instanceof RubyHash)) {
throw context.runtime.newArgumentError("Expected hash arguments.");
}
IRubyObject maxRecursionDepth = ((RubyHash) args[1]).fastARef(context.runtime.newSymbol("max_recursion_depth"));
if (maxRecursionDepth != null) {
input.setRecursionLimit(((RubyNumeric) maxRecursionDepth).getIntValue());
}
}
try { try {
ret.builder.mergeFrom(bin); ret.builder.mergeFrom(input);
} catch (InvalidProtocolBufferException e) { } catch (Exception e) {
throw RaiseException.from(context.runtime, (RubyClass) context.runtime.getClassFromPath("Google::Protobuf::ParseError"), e.getMessage()); throw RaiseException.from(context.runtime, (RubyClass) context.runtime.getClassFromPath("Google::Protobuf::ParseError"), e.getMessage());
} }
@ -541,7 +570,7 @@ public class RubyMessage extends RubyObject {
printer = printer.usingTypeRegistry(JsonFormat.TypeRegistry.newBuilder().add(message.descriptor).build()); printer = printer.usingTypeRegistry(JsonFormat.TypeRegistry.newBuilder().add(message.descriptor).build());
try { try {
result = printer.print(message.build(context)); result = printer.print(message.build(context, 0, SINK_MAXIMUM_NESTING));
} catch (InvalidProtocolBufferException e) { } catch (InvalidProtocolBufferException e) {
throw runtime.newRuntimeError(e.getMessage()); throw runtime.newRuntimeError(e.getMessage());
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
@ -635,12 +664,8 @@ public class RubyMessage extends RubyObject {
return ret; return ret;
} }
protected DynamicMessage build(ThreadContext context) { protected DynamicMessage build(ThreadContext context, int depth, int maxRecursionDepth) {
return build(context, 0); if (depth >= maxRecursionDepth) {
}
protected DynamicMessage build(ThreadContext context, int depth) {
if (depth > SINK_MAXIMUM_NESTING) {
throw context.runtime.newRuntimeError("Maximum recursion depth exceeded during encoding."); throw context.runtime.newRuntimeError("Maximum recursion depth exceeded during encoding.");
} }
@ -651,7 +676,7 @@ public class RubyMessage extends RubyObject {
if (value instanceof RubyMap) { if (value instanceof RubyMap) {
builder.clearField(fieldDescriptor); builder.clearField(fieldDescriptor);
RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor); RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor);
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) { for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) {
builder.addRepeatedField(fieldDescriptor, kv); builder.addRepeatedField(fieldDescriptor, kv);
} }
@ -660,7 +685,7 @@ public class RubyMessage extends RubyObject {
builder.clearField(fieldDescriptor); builder.clearField(fieldDescriptor);
for (int i = 0; i < repeatedField.size(); i++) { for (int i = 0; i < repeatedField.size(); i++) {
Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth, Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth, maxRecursionDepth,
/*isDefaultValueForBytes*/ false); /*isDefaultValueForBytes*/ false);
builder.addRepeatedField(fieldDescriptor, item); builder.addRepeatedField(fieldDescriptor, item);
} }
@ -682,7 +707,7 @@ public class RubyMessage extends RubyObject {
fieldDescriptor.getFullName().equals("google.protobuf.FieldDescriptorProto.default_value")) { fieldDescriptor.getFullName().equals("google.protobuf.FieldDescriptorProto.default_value")) {
isDefaultStringForBytes = true; isDefaultStringForBytes = true;
} }
builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, isDefaultStringForBytes)); builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, maxRecursionDepth, isDefaultStringForBytes));
} }
} }
@ -702,7 +727,7 @@ public class RubyMessage extends RubyObject {
builder.clearField(fieldDescriptor); builder.clearField(fieldDescriptor);
RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context, RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context,
fieldDescriptor); fieldDescriptor);
for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) { for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) {
builder.addRepeatedField(fieldDescriptor, kv); builder.addRepeatedField(fieldDescriptor, kv);
} }
} }
@ -814,7 +839,8 @@ public class RubyMessage extends RubyObject {
// convert a ruby object to protobuf type, skip type check since it is checked on the way in // convert a ruby object to protobuf type, skip type check since it is checked on the way in
private Object convert(ThreadContext context, private Object convert(ThreadContext context,
FieldDescriptor fieldDescriptor, FieldDescriptor fieldDescriptor,
IRubyObject value, int depth, boolean isDefaultStringForBytes) { IRubyObject value, int depth, int maxRecursionDepth,
boolean isDefaultStringForBytes) {
Object val = null; Object val = null;
switch (fieldDescriptor.getType()) { switch (fieldDescriptor.getType()) {
case INT32: case INT32:
@ -855,7 +881,7 @@ public class RubyMessage extends RubyObject {
} }
break; break;
case MESSAGE: case MESSAGE:
val = ((RubyMessage) value).build(context, depth + 1); val = ((RubyMessage) value).build(context, depth + 1, maxRecursionDepth);
break; break;
case ENUM: case ENUM:
EnumDescriptor enumDescriptor = fieldDescriptor.getEnumType(); EnumDescriptor enumDescriptor = fieldDescriptor.getEnumType();
@ -1214,7 +1240,7 @@ public class RubyMessage extends RubyObject {
private static final String CONST_SUFFIX = "_const"; private static final String CONST_SUFFIX = "_const";
private static final String HAS_PREFIX = "has_"; private static final String HAS_PREFIX = "has_";
private static final String QUESTION_MARK = "?"; private static final String QUESTION_MARK = "?";
private static final int SINK_MAXIMUM_NESTING = 63; private static final int SINK_MAXIMUM_NESTING = 64;
private Descriptor descriptor; private Descriptor descriptor;
private DynamicMessage.Builder builder; private DynamicMessage.Builder builder;

@ -101,4 +101,55 @@ class EncodeDecodeTest < Test::Unit::TestCase
assert_match json, "{\"CustomJsonName\":42}" assert_match json, "{\"CustomJsonName\":42}"
end end
# Checks that the max_recursion_depth option passed to decode limits how
# deeply nested a serialized message may be.
def test_decode_depth_limit
# An outer TestMessage wrapping five successively nested optional_msg
# sub-messages (six TestMessage instances in total).
msg = A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
)
)
)
)
)
)
# Baseline: with no options the message round-trips through encode/decode.
msg_encoded = A::B::C::TestMessage.encode(msg)
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json
# A limit of 4 is expected to be too small for this payload and must raise
# a ParseError.
assert_raise Google::Protobuf::ParseError do
A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 4 })
end
# A limit of 5 is expected to be enough to decode the same payload.
msg_out = A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 5 })
assert_match msg.to_json, msg_out.to_json
end
# Checks that the max_recursion_depth option passed to encode limits how
# deeply nested a message may be when it is serialized.
def test_encode_depth_limit
# An outer TestMessage wrapping five successively nested optional_msg
# sub-messages (six TestMessage instances in total).
msg = A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
optional_msg: A::B::C::TestMessage.new(
)
)
)
)
)
)
# Baseline: with no options the message round-trips through encode/decode.
msg_encoded = A::B::C::TestMessage.encode(msg)
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json
# A limit of 5 is expected to be too small for this message and must raise
# a RuntimeError.
assert_raise RuntimeError do
A::B::C::TestMessage.encode(msg, { max_recursion_depth: 5 })
end
# A limit of 6 is expected to be enough; the output must still decode back
# to the original message.
msg_encoded = A::B::C::TestMessage.encode(msg, { max_recursion_depth: 6 })
msg_out = A::B::C::TestMessage.decode(msg_encoded)
assert_match msg.to_json, msg_out.to_json
end
end end

Loading…
Cancel
Save