Nextgen Proto Pythonic API: Add 'in' operator

The “in” operator will be consistent with HasField but a little different with Proto Plus.

The detail behavior of “in” operator in Nextgen for Struct (to be consist with old Struct behavior):
-Raise TypeError if not pass a string
-Check if the key is in the struct.fields

The detail behavior of “in” operator in Nextgen(for other message):
-Raise ValueError if not pass a string
-Raise ValueError if the string is not a field
-For Oneof: Check any field under the oneof is set
-For has-presence field: check if set
-For non-has-presence field (include repeated fields): raise ValueError

PiperOrigin-RevId: 621240977
pull/16362/head
Jie Luo 10 months ago committed by Copybara-Service
parent 3a2cd26c13
commit de8e550e90
  1. 31
      python/google/protobuf/internal/message_test.py
  2. 12
      python/google/protobuf/internal/python_message.py
  3. 3
      python/google/protobuf/internal/well_known_types.py
  4. 9
      python/google/protobuf/internal/well_known_types_test.py
  5. 23
      python/google/protobuf/message.py
  6. 55
      python/google/protobuf/pyext/message.cc
  7. 20
      python/message.c

@ -1314,6 +1314,24 @@ class MessageTest(unittest.TestCase):
self.assertNotEqual(m, ComparesWithFoo())
self.assertNotEqual(ComparesWithFoo(), m)
def testIn(self, message_module):
m = message_module.TestAllTypes()
self.assertNotIn('optional_nested_message', m)
self.assertNotIn('oneof_bytes', m)
self.assertNotIn('oneof_string', m)
with self.assertRaises(ValueError) as e:
'repeated_int32' in m
with self.assertRaises(ValueError) as e:
'repeated_nested_message' in m
with self.assertRaises(ValueError) as e:
1 in m
with self.assertRaises(ValueError) as e:
'not_a_field' in m
test_util.SetAllFields(m)
self.assertIn('optional_nested_message', m)
self.assertIn('oneof_bytes', m)
self.assertNotIn('oneof_string', m)
# Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase
@ -1345,6 +1363,9 @@ class Proto2Test(unittest.TestCase):
self.assertTrue(message.HasField('optional_int32'))
self.assertTrue(message.HasField('optional_bool'))
self.assertTrue(message.HasField('optional_nested_message'))
self.assertIn('optional_int32', message)
self.assertIn('optional_bool', message)
self.assertIn('optional_nested_message', message)
# Set the fields to non-default values.
message.optional_int32 = 5
@ -1363,6 +1384,9 @@ class Proto2Test(unittest.TestCase):
self.assertFalse(message.HasField('optional_int32'))
self.assertFalse(message.HasField('optional_bool'))
self.assertFalse(message.HasField('optional_nested_message'))
self.assertNotIn('optional_int32', message)
self.assertNotIn('optional_bool', message)
self.assertNotIn('optional_nested_message', message)
self.assertEqual(0, message.optional_int32)
self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb)
@ -1689,6 +1713,12 @@ class Proto3Test(unittest.TestCase):
with self.assertRaises(ValueError):
message.HasField('repeated_nested_message')
# Can not test "in" operator.
with self.assertRaises(ValueError):
'repeated_int32' in message
with self.assertRaises(ValueError):
'repeated_nested_message' in message
# Fields should default to their type-specific default.
self.assertEqual(0, message.optional_int32)
self.assertEqual(0, message.optional_float)
@ -1699,6 +1729,7 @@ class Proto3Test(unittest.TestCase):
# Setting a submessage should still return proper presence information.
message.optional_nested_message.bb = 0
self.assertTrue(message.HasField('optional_nested_message'))
self.assertIn('optional_nested_message', message)
# Set the fields to non-default values.
message.optional_int32 = 5

@ -1000,6 +1000,17 @@ def _AddUnicodeMethod(unused_message_descriptor, cls):
cls.__unicode__ = __unicode__
def _AddContainsMethod(message_descriptor, cls):
def __contains__(self, field_or_key):
if (message_descriptor.full_name == 'google.protobuf.Struct'):
return field_or_key in self.fields
else:
return self.HasField(field_or_key)
cls.__contains__ = __contains__
def _BytesForNonRepeatedElement(value, field_number, field_type):
"""Returns the number of bytes needed to serialize a non-repeated element.
The returned byte count includes space for tag information and any
@ -1394,6 +1405,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddStrMethod(message_descriptor, cls)
_AddReprMethod(message_descriptor, cls)
_AddUnicodeMethod(message_descriptor, cls)
_AddContainsMethod(message_descriptor, cls)
_AddByteSizeMethod(message_descriptor, cls)
_AddSerializeToStringMethod(message_descriptor, cls)
_AddSerializePartialToStringMethod(message_descriptor, cls)

@ -497,9 +497,6 @@ class Struct(object):
def __getitem__(self, key):
return _GetStructValue(self.fields[key])
def __contains__(self, item):
return item in self.fields
def __setitem__(self, key, value):
_SetStructValue(self.fields[key], value)

@ -569,6 +569,15 @@ class StructTest(unittest.TestCase):
self.assertEqual([6, True, False, None, inner_struct],
list(struct['key5'].items()))
def testInOperator(self):
struct = struct_pb2.Struct()
struct['key'] = 5
self.assertIn('key', struct)
self.assertNotIn('fields', struct)
with self.assertRaises(TypeError) as e:
1 in struct
def testStructAssignment(self):
# Tests struct assignment from another struct
s1 = struct_pb2.Struct()

@ -75,6 +75,29 @@ class Message(object):
"""Outputs a human-readable representation of the message."""
raise NotImplementedError
def __contains__(self, field_name):
"""Checks if a certain field is set for the message.
Has presence fields return true if the field is set, false if the field is
not set. Fields without presence do raise `ValueError` (this includes
repeated fields, map fields, and implicit presence fields).
If field_name is not defined in the message descriptor, `ValueError` will
be raised.
Note: WKT Struct checks if the key is contained in fields.
Args:
field_name (str): The name of the field to check for presence.
Returns:
bool: Whether a value has been set for the named field.
Raises:
ValueError: if the `field_name` is not a member of this message or
`field_name` is not a string.
"""
raise NotImplementedError
def MergeFrom(self, other_msg):
"""Merges the contents of the specified message into current message.

@ -10,6 +10,7 @@
#include "google/protobuf/pyext/message.h"
#include <Python.h>
#include <structmember.h> // A Python header file.
#include <cstdint>
@ -36,6 +37,7 @@
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/unknown_field_set.h"
@ -85,6 +87,12 @@ class MessageReflectionFriend {
return reflection->IsLazyField(field) ||
reflection->IsLazyExtension(message, field);
}
static bool ContainsMapKey(const Reflection* reflection,
const Message& message,
const FieldDescriptor* field,
const MapKey& map_key) {
return reflection->ContainsMapKey(message, field, map_key);
}
};
static PyObject* kDESCRIPTOR;
@ -1293,11 +1301,16 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
char* field_name;
Py_ssize_t size;
field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
Message* message = self->message;
if (!field_name) {
PyErr_Format(PyExc_ValueError,
"The field name passed to message %s"
" is not a str.",
message->GetDescriptor()->name().c_str());
return nullptr;
}
Message* message = self->message;
bool is_in_oneof;
const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
message, absl::string_view(field_name, size), &is_in_oneof);
@ -2290,6 +2303,44 @@ PyObject* ToUnicode(CMessage* self) {
return decoded;
}
PyObject* Contains(CMessage* self, PyObject* arg) {
Message* message = self->message;
const Descriptor* descriptor = message->GetDescriptor();
// For WKT Struct, check if the key is in the fields.
if (descriptor->full_name() == "google.protobuf.Struct") {
const Reflection* reflection = message->GetReflection();
const FieldDescriptor* map_field = descriptor->FindFieldByName("fields");
const FieldDescriptor* key_field = map_field->message_type()->map_key();
PyObject* py_string = CheckString(arg, key_field);
if (!py_string) {
PyErr_SetString(PyExc_TypeError,
"The key passed to Struct message must be a str.");
return nullptr;
}
char* value;
Py_ssize_t value_len;
if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
Py_DECREF(py_string);
Py_RETURN_FALSE;
}
std::string key_str;
key_str.assign(value, value_len);
Py_DECREF(py_string);
MapKey map_key;
map_key.SetStringValue(key_str);
if (MessageReflectionFriend::ContainsMapKey(reflection, *message, map_field,
map_key)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
}
// For other messages, check with HasField.
return HasField(self, arg);
}
// CMessage static methods:
PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
PyObject* unused_arg) {
@ -2338,6 +2389,8 @@ static PyMethodDef Methods[] = {
"Makes a deep copy of the class."},
{"__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
"Outputs a unicode representation of the message."},
{"__contains__", (PyCFunction)Contains, METH_O,
"Checks if a message field is set."},
{"ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
"Returns the size of the message in bytes."},
{"Clear", (PyCFunction)Clear, METH_NOARGS, "Clears the message."},

@ -1042,6 +1042,24 @@ static PyObject* PyUpb_Message_HasField(PyObject* _self, PyObject* arg) {
NULL);
}
static PyObject* PyUpb_Message_Contains(PyObject* _self, PyObject* arg) {
const upb_MessageDef* msgdef = PyUpb_Message_GetMsgdef(_self);
// For WKT Struct, check if the key is in the fields.
if (strcmp(upb_MessageDef_FullName(msgdef), "google.protobuf.Struct") == 0) {
PyUpb_Message* self = (void*)_self;
upb_Message* msg = PyUpb_Message_GetMsg(self);
const upb_FieldDef* f = upb_MessageDef_FindFieldByName(msgdef, "fields");
const upb_Map* map = upb_Message_GetFieldByDef(msg, f).map_val;
const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f);
const upb_FieldDef* key_f = upb_MessageDef_Field(entry_m, 0);
upb_MessageValue u_key;
if (!PyUpb_PyToUpb(arg, key_f, &u_key, NULL)) return NULL;
return PyBool_FromLong(upb_Map_Get(map, u_key, NULL));
}
// For other messages, check with HasField.
return PyUpb_Message_HasField(_self, arg);
}
static PyObject* PyUpb_Message_FindInitializationErrors(PyObject* _self,
PyObject* arg);
@ -1640,6 +1658,8 @@ static PyMethodDef PyUpb_Message_Methods[] = {
// TODO
//{ "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
// "Outputs a unicode representation of the message." },
{"__contains__", PyUpb_Message_Contains, METH_O,
"Checks if a message field is set."},
{"ByteSize", (PyCFunction)PyUpb_Message_ByteSize, METH_NOARGS,
"Returns the size of the message in bytes."},
{"Clear", (PyCFunction)PyUpb_Message_Clear, METH_NOARGS,

Loading…
Cancel
Save