Nextgen Proto Pythonic API: Struct/ListValue assignment and creation

Python dict is now able to be assigned (by create and copy, not reference) and compared with the Protobuf Struct field.
Python list is now able to be assigned (by create and copy, not reference) and compared with the Protobuf ListValue field.

example usage:
  dictionary = {'key1': 5.0, 'key2': {'subkey': 11.0, 'k': False},}
  list_value = [6, 'seven', True, False, None, dictionary]
  msg = more_messages_pb2.WKTMessage(
      optional_struct=dictionary, optional_list_value=list_value
  )
  self.assertEqual(msg.optional_struct, dictionary)
  self.assertEqual(msg.optional_list_value, list_value)

PiperOrigin-RevId: 646099987
pull/17175/head
Jie Luo 8 months ago committed by Copybara-Service
parent 0302c4c438
commit e17821cac1
  1. 6
      python/google/protobuf/internal/descriptor_pool_test.py
  2. 3
      python/google/protobuf/internal/more_messages.proto
  3. 59
      python/google/protobuf/internal/python_message.py
  4. 40
      python/google/protobuf/internal/well_known_types.py
  5. 67
      python/google/protobuf/internal/well_known_types_test.py
  6. 106
      python/google/protobuf/pyext/message.cc
  7. 78
      python/message.c

@ -30,6 +30,7 @@ from google.protobuf.internal import no_package_pb2
from google.protobuf.internal import testing_refleaks
from google.protobuf import duration_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import unittest_features_pb2
from google.protobuf import unittest_import_pb2
@ -439,6 +440,7 @@ class DescriptorPoolTestBase(object):
self.testFindMessageTypeByName()
self.pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb)
self.pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb)
self.pool.AddSerializedFile(struct_pb2.DESCRIPTOR.serialized_pb)
file_json = self.pool.AddSerializedFile(
more_messages_pb2.DESCRIPTOR.serialized_pb)
field = file_json.message_types_by_name['class'].fields_by_name['int_field']
@ -550,6 +552,9 @@ class DescriptorPoolTestBase(object):
timestamp_pb2.DESCRIPTOR.serialized_pb)
duration_desc = descriptor_pb2.FileDescriptorProto.FromString(
duration_pb2.DESCRIPTOR.serialized_pb)
struct_desc = descriptor_pb2.FileDescriptorProto.FromString(
struct_pb2.DESCRIPTOR.serialized_pb
)
more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
more_messages_pb2.DESCRIPTOR.serialized_pb)
test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
@ -558,6 +563,7 @@ class DescriptorPoolTestBase(object):
descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
self.pool.Add(timestamp_desc)
self.pool.Add(duration_desc)
self.pool.Add(struct_desc)
self.pool.Add(more_messages_desc)
self.pool.Add(test1_desc)
self.pool.Add(test2_desc)

@ -14,6 +14,7 @@ syntax = "proto2";
package google.protobuf.internal;
import "google/protobuf/duration.proto";
import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";
// A message where tag numbers are listed out of order, to allow us to test our
@ -355,4 +356,6 @@ message ConflictJsonName {
message WKTMessage {
optional Timestamp optional_timestamp = 1;
optional Duration optional_duration = 2;
optional Struct optional_struct = 3;
optional ListValue optional_list_value = 4;
}

@ -51,6 +51,8 @@ from google.protobuf.internal import wire_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
_AnyFullTypeName = 'google.protobuf.Any'
_StructFullTypeName = 'google.protobuf.Struct'
_ListValueFullTypeName = 'google.protobuf.ListValue'
_ExtensionDict = extension_dict._ExtensionDict
class GeneratedProtocolMessageType(type):
@ -515,37 +517,47 @@ def _AddInitMethod(message_descriptor, cls):
# field=None is the same as no field at all.
continue
if field.label == _FieldDescriptor.LABEL_REPEATED:
copy = field._default_constructor(self)
field_copy = field._default_constructor(self)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
if _IsMapField(field):
if _IsMessageMapField(field):
for key in field_value:
copy[key].MergeFrom(field_value[key])
field_copy[key].MergeFrom(field_value[key])
else:
copy.update(field_value)
field_copy.update(field_value)
else:
for val in field_value:
if isinstance(val, dict):
copy.add(**val)
field_copy.add(**val)
else:
copy.add().MergeFrom(val)
field_copy.add().MergeFrom(val)
else: # Scalar
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
field_value = [_GetIntegerEnumValue(field.enum_type, val)
for val in field_value]
copy.extend(field_value)
self._fields[field] = copy
field_copy.extend(field_value)
self._fields[field] = field_copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
field_copy = field._default_constructor(self)
new_val = None
if isinstance(field_value, message_mod.Message):
new_val = field_value
elif isinstance(field_value, dict):
new_val = field.message_type._concrete_class(**field_value)
elif field.message_type.full_name == 'google.protobuf.Timestamp':
copy.FromDatetime(field_value)
elif field.message_type.full_name == 'google.protobuf.Duration':
copy.FromTimedelta(field_value)
if field.message_type.full_name == _StructFullTypeName:
field_copy.Clear()
if len(field_value) == 1 and 'fields' in field_value:
try:
field_copy.update(field_value)
except:
# Fall back to init normal message field
field_copy.Clear()
new_val = field.message_type._concrete_class(**field_value)
else:
field_copy.update(field_value)
else:
new_val = field.message_type._concrete_class(**field_value)
elif hasattr(field_copy, '_internal_assign'):
field_copy._internal_assign(field_value)
else:
raise TypeError(
'Message field {0}.{1} must be initialized with a '
@ -558,10 +570,10 @@ def _AddInitMethod(message_descriptor, cls):
if new_val:
try:
copy.MergeFrom(new_val)
field_copy.MergeFrom(new_val)
except TypeError:
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
self._fields[field] = copy
self._fields[field] = field_copy
else:
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
field_value = _GetIntegerEnumValue(field.enum_type, field_value)
@ -777,6 +789,14 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
elif field.message_type.full_name == 'google.protobuf.Duration':
getter(self)
self._fields[field].FromTimedelta(new_value)
elif field.message_type.full_name == _StructFullTypeName:
getter(self)
self._fields[field].Clear()
self._fields[field].update(new_value)
elif field.message_type.full_name == _ListValueFullTypeName:
getter(self)
self._fields[field].Clear()
self._fields[field].extend(new_value)
else:
raise AttributeError(
'Assignment not allowed to composite field '
@ -978,6 +998,15 @@ def _InternalUnpackAny(msg):
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def __eq__(self, other):
if self.DESCRIPTOR.full_name == _ListValueFullTypeName and isinstance(
other, list
):
return self._internal_compare(other)
if self.DESCRIPTOR.full_name == _StructFullTypeName and isinstance(
other, dict
):
return self._internal_compare(other)
if (not isinstance(other, message_mod.Message) or
other.DESCRIPTOR != self.DESCRIPTOR):
return NotImplemented

@ -283,6 +283,9 @@ class Timestamp(object):
self.seconds = seconds
self.nanos = nanos
def _internal_assign(self, dt):
self.FromDatetime(dt)
def __add__(self, value) -> datetime.datetime:
if isinstance(value, Duration):
return self.ToDatetime() + value.ToTimedelta()
@ -442,6 +445,9 @@ class Duration(object):
'object got {0}: {1}'.format(type(td).__name__, e)
) from e
def _internal_assign(self, td):
self.FromTimedelta(td)
def _NormalizeDuration(self, seconds, nanos):
"""Set Duration by seconds and nanos."""
# Force nanos to be negative if the duration is negative.
@ -550,6 +556,24 @@ class Struct(object):
def __iter__(self):
return iter(self.fields)
def _internal_assign(self, dictionary):
self.Clear()
self.update(dictionary)
def _internal_compare(self, other):
size = len(self)
if size != len(other):
return False
for key, value in self.items():
if key not in other:
return False
if isinstance(other[key], (dict, list)):
if not value._internal_compare(other[key]):
return False
elif value != other[key]:
return False
return True
def keys(self): # pylint: disable=invalid-name
return self.fields.keys()
@ -605,6 +629,22 @@ class ListValue(object):
def __delitem__(self, key):
del self.values[key]
def _internal_assign(self, elem_seq):
self.Clear()
self.extend(elem_seq)
def _internal_compare(self, other):
size = len(self)
if size != len(other):
return False
for i in range(size):
if isinstance(other[i], (dict, list)):
if not self[i]._internal_compare(other[i]):
return False
elif self[i] != other[i]:
return False
return True
def items(self):
for i in range(len(self)):
yield self[i]

@ -838,6 +838,73 @@ class StructTest(unittest.TestCase):
s2['x'] = s1['x']
self.assertEqual(s1['x'], s2['x'])
dictionary = {
'key1': 5.0,
'key2': 'abc',
'key3': {'subkey': 11.0, 'k': False},
}
msg = more_messages_pb2.WKTMessage()
msg.optional_struct = dictionary
self.assertEqual(msg.optional_struct, dictionary)
# Tests assign is not merge
dictionary2 = {
'key4': {'subkey': 11.0, 'k': True},
}
msg.optional_struct = dictionary2
self.assertEqual(msg.optional_struct, dictionary2)
# Tests assign empty
msg2 = more_messages_pb2.WKTMessage()
self.assertNotIn('optional_struct', msg2)
msg2.optional_struct = {}
self.assertIn('optional_struct', msg2)
self.assertEqual(msg2.optional_struct, {})
def testListValueAssignment(self):
list_value = [6, 'seven', True, False, None, {}]
msg = more_messages_pb2.WKTMessage()
msg.optional_list_value = list_value
self.assertEqual(msg.optional_list_value, list_value)
def testStructConstruction(self):
dictionary = {
'key1': 5.0,
'key2': 'abc',
'key3': {'subkey': 11.0, 'k': False},
}
list_value = [6, 'seven', True, False, None, dictionary]
msg = more_messages_pb2.WKTMessage(
optional_struct=dictionary, optional_list_value=list_value
)
self.assertEqual(len(msg.optional_struct), len(dictionary))
self.assertEqual(msg.optional_struct, dictionary)
self.assertEqual(len(msg.optional_list_value), len(list_value))
self.assertEqual(msg.optional_list_value, list_value)
msg2 = more_messages_pb2.WKTMessage(
optional_struct={}, optional_list_value=[]
)
self.assertIn('optional_struct', msg2)
self.assertIn('optional_list_value', msg2)
self.assertEqual(msg2.optional_struct, {})
self.assertEqual(msg2.optional_list_value, [])
def testSpecialStructConstruct(self):
dictionary = {'key1': 6.0}
msg = more_messages_pb2.WKTMessage(optional_struct=dictionary)
self.assertEqual(msg.optional_struct, dictionary)
dictionary2 = {'fields': 7.0}
msg2 = more_messages_pb2.WKTMessage(optional_struct=dictionary2)
self.assertEqual(msg2.optional_struct, dictionary2)
# Construct Struct as normal message
value_msg = struct_pb2.Value(number_value=5.0)
dictionary3 = {'fields': {'key1': value_msg}}
msg3 = more_messages_pb2.WKTMessage(optional_struct=dictionary3)
self.assertEqual(msg3.optional_struct, {'key1': 5.0})
def testMergeFrom(self):
struct = struct_pb2.Struct()
struct_class = struct.__class__

@ -1089,8 +1089,24 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
if (PyDict_Check(value)) {
// Make the message exist even if the dict is empty.
AssureWritable(cmessage);
if (InitAttributes(cmessage, nullptr, value) < 0) {
return -1;
if (descriptor->message_type()->well_known_type() ==
Descriptor::WELLKNOWNTYPE_STRUCT) {
ScopedPyObjectPtr ok(PyObject_CallMethod(
reinterpret_cast<PyObject*>(cmessage), "update", "O", value));
if (ok.get() == nullptr && PyDict_Size(value) == 1 &&
PyDict_Contains(value, PyUnicode_FromString("fields"))) {
// Fallback to init as normal message field.
PyErr_Clear();
PyObject* tmp = Clear(cmessage);
Py_DECREF(tmp);
if (InitAttributes(cmessage, nullptr, value) < 0) {
return -1;
}
}
} else {
if (InitAttributes(cmessage, nullptr, value) < 0) {
return -1;
}
}
} else {
if (PyObject_TypeCheck(value, CMessage_Type)) {
@ -1099,34 +1115,24 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
return -1;
}
} else {
switch (descriptor->message_type()->well_known_type()) {
case Descriptor::WELLKNOWNTYPE_TIMESTAMP: {
AssureWritable(cmessage);
ScopedPyObjectPtr ok(
PyObject_CallMethod(reinterpret_cast<PyObject*>(cmessage),
"FromDatetime", "O", value));
if (ok.get() == nullptr) {
return -1;
}
break;
}
case Descriptor::WELLKNOWNTYPE_DURATION: {
AssureWritable(cmessage);
ScopedPyObjectPtr ok(
PyObject_CallMethod(reinterpret_cast<PyObject*>(cmessage),
"FromTimedelta", "O", value));
if (ok.get() == nullptr) {
return -1;
}
break;
}
default:
PyErr_Format(
PyExc_TypeError,
"Parameter to initialize message field must be "
"dict or instance of same class: expected %s got %s.",
descriptor->full_name().c_str(), Py_TYPE(value)->tp_name);
if (descriptor->message_type()->well_known_type() !=
Descriptor::WELLKNOWNTYPE_UNSPECIFIED &&
PyObject_HasAttrString(reinterpret_cast<PyObject*>(cmessage),
"_internal_assign")) {
AssureWritable(cmessage);
ScopedPyObjectPtr ok(
PyObject_CallMethod(reinterpret_cast<PyObject*>(cmessage),
"_internal_assign", "O", value));
if (ok.get() == nullptr) {
return -1;
}
} else {
PyErr_Format(PyExc_TypeError,
"Parameter to initialize message field must be "
"dict or instance of same class: expected %s got %s.",
descriptor->full_name().c_str(),
Py_TYPE(value)->tp_name);
return -1;
}
}
}
@ -2040,6 +2046,15 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
const Descriptor* self_descriptor = self->message->GetDescriptor();
Descriptor::WellKnownType wkt = self_descriptor->well_known_type();
if ((wkt == Descriptor::WELLKNOWNTYPE_LISTVALUE && PyList_Check(other)) ||
(wkt == Descriptor::WELLKNOWNTYPE_STRUCT && PyDict_Check(other))) {
return PyObject_CallMethod(reinterpret_cast<PyObject*>(self),
"_internal_compare", "O", other);
}
// If other is not a message, this implementation doesn't know how to perform
// comparisons.
if (!PyObject_TypeCheck(other, CMessage_Type)) {
@ -2051,8 +2066,7 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
const google::protobuf::Message* other_message =
reinterpret_cast<CMessage*>(other)->message;
// If messages don't have the same descriptors, they are not equal.
if (equals &&
self->message->GetDescriptor() != other_message->GetDescriptor()) {
if (equals && self_descriptor != other_message->GetDescriptor()) {
equals = false;
}
// Check the message contents.
@ -2585,34 +2599,24 @@ int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
field_descriptor->name().c_str());
return -1;
} else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
switch (field_descriptor->message_type()->well_known_type()) {
case Descriptor::WELLKNOWNTYPE_TIMESTAMP: {
AssureWritable(self);
PyObject* sub_message = GetFieldValue(self, field_descriptor);
ScopedPyObjectPtr ok(
PyObject_CallMethod(sub_message, "FromDatetime", "O", value));
if (ok.get() == nullptr) {
return -1;
}
return 0;
}
case Descriptor::WELLKNOWNTYPE_DURATION: {
if (field_descriptor->message_type()->well_known_type() !=
Descriptor::WELLKNOWNTYPE_UNSPECIFIED) {
PyObject* sub_message = GetFieldValue(self, field_descriptor);
if (PyObject_HasAttrString(sub_message, "_internal_assign")) {
AssureWritable(self);
PyObject* sub_message = GetFieldValue(self, field_descriptor);
ScopedPyObjectPtr ok(
PyObject_CallMethod(sub_message, "FromTimedelta", "O", value));
PyObject_CallMethod(sub_message, "_internal_assign", "O", value));
if (ok.get() == nullptr) {
return -1;
}
return 0;
}
default:
PyErr_Format(PyExc_AttributeError,
"Assignment not allowed to "
"field \"%s\" in protocol message object.",
field_descriptor->name().c_str());
return -1;
}
PyErr_Format(PyExc_AttributeError,
"Assignment not allowed to "
"field \"%s\" in protocol message object.",
field_descriptor->name().c_str());
return -1;
} else {
AssureWritable(self);
return InternalSetScalar(self, field_descriptor, value);

@ -432,6 +432,8 @@ err:
return ok;
}
static PyObject* PyUpb_Message_Clear(PyUpb_Message* self);
static bool PyUpb_Message_InitMessageAttribute(PyObject* _self, PyObject* name,
const upb_FieldDef* field,
PyObject* value) {
@ -445,25 +447,31 @@ static bool PyUpb_Message_InitMessageAttribute(PyObject* _self, PyObject* name,
Py_XDECREF(tmp);
} else if (PyDict_Check(value)) {
assert(!PyErr_Occurred());
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
} else {
const upb_MessageDef* msgdef = upb_FieldDef_MessageSubDef(field);
switch (upb_MessageDef_WellKnownType(msgdef)) {
case kUpb_WellKnown_Timestamp: {
ok = PyObject_CallMethod(submsg, "FromDatetime", "O", value);
break;
}
case kUpb_WellKnown_Duration: {
ok = PyObject_CallMethod(submsg, "FromTimedelta", "O", value);
break;
}
default: {
const upb_MessageDef* m = PyUpb_Message_GetMsgdef(_self);
PyErr_Format(PyExc_TypeError,
"Message must be initialized with a dict: %s",
upb_MessageDef_FullName(m));
ok = false;
if (upb_MessageDef_WellKnownType(msgdef) == kUpb_WellKnown_Struct) {
ok = PyObject_CallMethod(submsg, "_internal_assign", "O", value);
if (!ok && PyDict_Size(value) == 1 &&
PyDict_Contains(value, PyUnicode_FromString("fields"))) {
// Fall back to init as normal message field.
PyErr_Clear();
PyObject* tmp = PyUpb_Message_Clear((PyUpb_Message*)submsg);
Py_DECREF(tmp);
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
}
} else {
ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0;
}
} else {
const upb_MessageDef* msgdef = upb_FieldDef_MessageSubDef(field);
if (upb_MessageDef_WellKnownType(msgdef) != kUpb_WellKnown_Unspecified &&
PyObject_HasAttrString(submsg, "_internal_assign")) {
ok = PyObject_CallMethod(submsg, "_internal_assign", "O", value);
} else {
const upb_MessageDef* m = PyUpb_Message_GetMsgdef(_self);
PyErr_Format(PyExc_TypeError,
"Message must be initialized with a dict: %s",
upb_MessageDef_FullName(m));
ok = false;
}
}
Py_DECREF(submsg);
@ -772,6 +780,13 @@ static PyObject* PyUpb_Message_RichCompare(PyObject* _self, PyObject* other,
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
const upb_MessageDef* msgdef = _PyUpb_Message_GetMsgdef(self);
upb_WellKnown wkt = upb_MessageDef_WellKnownType(msgdef);
if ((wkt == kUpb_WellKnown_ListValue && PyList_Check(other)) ||
(wkt == kUpb_WellKnown_Struct && PyDict_Check(other))) {
return PyObject_CallMethod(_self, "_internal_compare", "O", other);
}
if (!PyObject_TypeCheck(other, Py_TYPE(self))) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
@ -962,30 +977,21 @@ int PyUpb_Message_SetFieldValue(PyObject* _self, const upb_FieldDef* field,
if (upb_FieldDef_IsSubMessage(field)) {
const upb_MessageDef* msgdef = upb_FieldDef_MessageSubDef(field);
switch (upb_MessageDef_WellKnownType(msgdef)) {
case kUpb_WellKnown_Timestamp: {
PyObject* sub_message = PyUpb_Message_GetFieldValue(_self, field);
PyObject* ok =
PyObject_CallMethod(sub_message, "FromDatetime", "O", value);
if (!ok) return -1;
Py_DECREF(ok);
return 0;
}
case kUpb_WellKnown_Duration: {
PyObject* sub_message = PyUpb_Message_GetFieldValue(_self, field);
if (upb_MessageDef_WellKnownType(msgdef) != kUpb_WellKnown_Unspecified) {
PyObject* sub_message = PyUpb_Message_GetFieldValue(_self, field);
if (PyObject_HasAttrString(sub_message, "_internal_assign")) {
PyObject* ok =
PyObject_CallMethod(sub_message, "FromTimedelta", "O", value);
PyObject_CallMethod(sub_message, "_internal_assign", "O", value);
if (!ok) return -1;
Py_DECREF(ok);
return 0;
}
default:
PyErr_Format(exc,
"Assignment not allowed to message "
"field \"%s\" in protocol message object.",
upb_FieldDef_Name(field));
return -1;
}
PyErr_Format(exc,
"Assignment not allowed to message "
"field \"%s\" in protocol message object.",
upb_FieldDef_Name(field));
return -1;
}
upb_MessageValue val;
@ -1272,8 +1278,6 @@ PyObject* PyUpb_Message_MergeFrom(PyObject* self, PyObject* arg) {
Py_RETURN_NONE;
}
static PyObject* PyUpb_Message_Clear(PyUpb_Message* self);
static PyObject* PyUpb_Message_CopyFrom(PyObject* _self, PyObject* arg) {
if (_self->ob_type != arg->ob_type) {
PyErr_Format(PyExc_TypeError,

Loading…
Cancel
Save