From 62f1470f096168e88519560b2e8faf1e5bff70b0 Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Thu, 27 Apr 2023 11:57:42 -0700 Subject: [PATCH] -Make message CopyFrom() call UPB code instead of default implementation for python UPB -Add upb_Message_Clear() in UPB accessors -Add upb_Message_DeepCopy() in UPB copy -Fix UPB upb_Message_DeepClone() bug for repeated extensions PiperOrigin-RevId: 527644009 --- protos/protos.h | 16 +++---- python/BUILD | 1 + python/message.c | 47 ++++++++++++++------ upb/message/accessors.h | 7 +++ upb/message/copy.c | 92 ++++++++++++++++++++++++---------------- upb/message/copy.h | 4 ++ upb/message/internal.h | 3 -- upb/message/message.c | 6 --- upb/reflection/message.c | 2 +- 9 files changed, 110 insertions(+), 68 deletions(-) diff --git a/protos/protos.h b/protos/protos.h index d3ec430cc5..ca9f2da5d3 100644 --- a/protos/protos.h +++ b/protos/protos.h @@ -360,7 +360,7 @@ absl::StatusOr> GetExtension( template bool Parse(T& message, absl::string_view bytes) { - _upb_Message_Clear(message.msg(), T::minitable()); + upb_Message_Clear(message.msg(), T::minitable()); auto* arena = static_cast(message.GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(), /* extreg= */ nullptr, /* options= */ 0, @@ -370,7 +370,7 @@ bool Parse(T& message, absl::string_view bytes) { template bool Parse(T& message, absl::string_view bytes, const ::protos::ExtensionRegistry& extension_registry) { - _upb_Message_Clear(message.msg(), T::minitable()); + upb_Message_Clear(message.msg(), T::minitable()); auto* arena = static_cast(message.GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(), /* extreg= */ @@ -380,7 +380,7 @@ bool Parse(T& message, absl::string_view bytes, template bool Parse(Ptr& message, absl::string_view bytes) { - _upb_Message_Clear(message->msg(), T::minitable()); + upb_Message_Clear(message->msg(), T::minitable()); auto* arena = 
static_cast(message->GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), /* extreg= */ nullptr, /* options= */ 0, @@ -390,7 +390,7 @@ bool Parse(Ptr& message, absl::string_view bytes) { template bool Parse(Ptr& message, absl::string_view bytes, const ::protos::ExtensionRegistry& extension_registry) { - _upb_Message_Clear(message->msg(), T::minitable()); + upb_Message_Clear(message->msg(), T::minitable()); auto* arena = static_cast(message->GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), /* extreg= */ @@ -400,7 +400,7 @@ bool Parse(Ptr& message, absl::string_view bytes, template bool Parse(std::unique_ptr& message, absl::string_view bytes) { - _upb_Message_Clear(message->msg(), T::minitable()); + upb_Message_Clear(message->msg(), T::minitable()); auto* arena = static_cast(message->GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), /* extreg= */ nullptr, /* options= */ 0, @@ -410,7 +410,7 @@ bool Parse(std::unique_ptr& message, absl::string_view bytes) { template bool Parse(std::unique_ptr& message, absl::string_view bytes, const ::protos::ExtensionRegistry& extension_registry) { - _upb_Message_Clear(message->msg(), T::minitable()); + upb_Message_Clear(message->msg(), T::minitable()); auto* arena = static_cast(message->GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), /* extreg= */ @@ -420,7 +420,7 @@ bool Parse(std::unique_ptr& message, absl::string_view bytes, template bool Parse(std::shared_ptr& message, absl::string_view bytes) { - _upb_Message_Clear(message->msg(), T::minitable()); + upb_Message_Clear(message->msg(), T::minitable()); auto* arena = static_cast(message->GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), /* extreg= */ nullptr, /* options= */ 0, @@ -430,7 +430,7 @@ bool Parse(std::shared_ptr& message, 
absl::string_view bytes) { template bool Parse(std::shared_ptr& message, absl::string_view bytes, const ::protos::ExtensionRegistry& extension_registry) { - _upb_Message_Clear(message->msg(), T::minitable()); + upb_Message_Clear(message->msg(), T::minitable()); auto* arena = static_cast(message->GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), message->msg(), T::minitable(), /* extreg= */ diff --git a/python/BUILD b/python/BUILD index 5f1e23546f..798ecb00e5 100644 --- a/python/BUILD +++ b/python/BUILD @@ -241,6 +241,7 @@ py_extension( "//:descriptor_upb_proto_reflection", "//:eps_copy_input_stream", "//:hash", + "//:message_copy", "//:port", "//:reflection", "//:textformat", diff --git a/python/message.c b/python/message.c index 7c876a7835..cc5bd77dfc 100644 --- a/python/message.c +++ b/python/message.c @@ -32,6 +32,7 @@ #include "python/extension_dict.h" #include "python/map.h" #include "python/repeated.h" +#include "upb/message/copy.h" #include "upb/reflection/def.h" #include "upb/reflection/message.h" #include "upb/text/encode.h" @@ -325,12 +326,7 @@ static bool PyUpb_Message_LookupName(PyUpb_Message* self, PyObject* py_name, static bool PyUpb_Message_InitMessageMapEntry(PyObject* dst, PyObject* src) { if (!src || !dst) return false; - // TODO(haberman): Currently we are doing Clear()+MergeFrom(). Replace with - // CopyFrom() once that is implemented. 
- PyObject* ok = PyObject_CallMethod(dst, "Clear", NULL); - if (!ok) return false; - Py_DECREF(ok); - ok = PyObject_CallMethod(dst, "MergeFrom", "O", src); + PyObject* ok = PyObject_CallMethod(dst, "CopyFrom", "O", src); if (!ok) return false; Py_DECREF(ok); @@ -1218,6 +1214,34 @@ static PyObject* PyUpb_Message_MergePartialFrom(PyObject* self, PyObject* arg) { return PyUpb_Message_MergeInternal(self, arg, false); } +static PyObject* PyUpb_Message_Clear(PyUpb_Message* self); + +static PyObject* PyUpb_Message_CopyFrom(PyObject* _self, PyObject* arg) { + if (_self->ob_type != arg->ob_type) { + PyErr_Format(PyExc_TypeError, + "Parameter to CopyFrom() must be instance of same class: " + "expected %S got %S.", + Py_TYPE(_self), Py_TYPE(arg)); + return NULL; + } + if (_self == arg) { + Py_RETURN_NONE; + } + PyUpb_Message* self = (void*)_self; + PyUpb_Message* other = (void*)arg; + PyUpb_Message_EnsureReified(self); + + PyObject* tmp = PyUpb_Message_Clear(self); + Py_DECREF(tmp); + + upb_Message_DeepCopy(self->ptr.msg, other->ptr.msg, + upb_MessageDef_MiniTable(other->def), + PyUpb_Arena_Get(self->arena)); + PyUpb_Message_SyncSubobjs(self); + + Py_RETURN_NONE; +} + static PyObject* PyUpb_Message_SetInParent(PyObject* _self, PyObject* arg) { PyUpb_Message* self = (void*)_self; PyUpb_Message_EnsureReified(self); @@ -1269,10 +1293,8 @@ PyObject* PyUpb_Message_MergeFromString(PyObject* _self, PyObject* arg) { return PyLong_FromSsize_t(size); } -static PyObject* PyUpb_Message_Clear(PyUpb_Message* self, PyObject* args); - static PyObject* PyUpb_Message_ParseFromString(PyObject* self, PyObject* arg) { - PyObject* tmp = PyUpb_Message_Clear((PyUpb_Message*)self, NULL); + PyObject* tmp = PyUpb_Message_Clear((PyUpb_Message*)self); Py_DECREF(tmp); return PyUpb_Message_MergeFromString(self, arg); } @@ -1290,7 +1312,7 @@ static PyObject* PyUpb_Message_ByteSize(PyObject* self, PyObject* args) { return PyLong_FromSize_t(size); } -static PyObject* PyUpb_Message_Clear(PyUpb_Message* self, 
PyObject* args) { +static PyObject* PyUpb_Message_Clear(PyUpb_Message* self) { PyUpb_Message_EnsureReified(self); const upb_MessageDef* msgdef = _PyUpb_Message_GetMsgdef(self); PyUpb_WeakMap* subobj_map = self->unset_subobj_map; @@ -1620,9 +1642,8 @@ static PyMethodDef PyUpb_Message_Methods[] = { {"ClearExtension", PyUpb_Message_ClearExtension, METH_O, "Clears a message field."}, {"ClearField", PyUpb_Message_ClearField, METH_O, "Clears a message field."}, - // TODO(https://github.com/protocolbuffers/upb/issues/459) - //{ "CopyFrom", (PyCFunction)CopyFrom, METH_O, - // "Copies a protocol message into the current message." }, + {"CopyFrom", PyUpb_Message_CopyFrom, METH_O, + "Copies a protocol message into the current message."}, {"DiscardUnknownFields", (PyCFunction)PyUpb_Message_DiscardUnknownFields, METH_NOARGS, "Discards the unknown fields."}, {"FindInitializationErrors", PyUpb_Message_FindInitializationErrors, diff --git a/upb/message/accessors.h b/upb/message/accessors.h index c23c477f5d..24703cc845 100644 --- a/upb/message/accessors.h +++ b/upb/message/accessors.h @@ -57,6 +57,13 @@ UPB_API_INLINE void upb_Message_ClearField(upb_Message* msg, } } +UPB_API_INLINE void upb_Message_Clear(upb_Message* msg, + const upb_MiniTable* l) { + // Note: Can't use UPB_PTR_AT() here because we are doing pointer subtraction. + char* mem = (char*)msg - sizeof(upb_Message_Internal); + memset(mem, 0, upb_msg_sizeof(l)); +} + UPB_API_INLINE bool upb_Message_HasField(const upb_Message* msg, const upb_MiniTableField* field) { if (upb_MiniTableField_IsExtension(field)) { diff --git a/upb/message/copy.c b/upb/message/copy.c index 1d15d088bb..5b3f37e976 100644 --- a/upb/message/copy.c +++ b/upb/message/copy.c @@ -187,42 +187,36 @@ static bool upb_Clone_ExtensionValue( mini_table_ext->sub.submsg, arena); } -// Deep clones a message using the provided target arena. -// -// Returns NULL on failure. 
-upb_Message* upb_Message_DeepClone(const upb_Message* message, - const upb_MiniTable* mini_table, - upb_Arena* arena) { - upb_Message* clone = upb_Message_New(mini_table, arena); +upb_Message* _upb_Message_Copy(upb_Message* dst, const upb_Message* src, + const upb_MiniTable* mini_table, + upb_Arena* arena) { upb_StringView empty_string = upb_StringView_FromDataAndSize(NULL, 0); // Only copy message area skipping upb_Message_Internal. - memcpy(clone, message, mini_table->size); + memcpy(dst, src, mini_table->size); for (size_t i = 0; i < mini_table->field_count; ++i) { const upb_MiniTableField* field = &mini_table->fields[i]; if (!upb_IsRepeatedOrMap(field)) { switch (upb_MiniTableField_CType(field)) { case kUpb_CType_Message: { const upb_Message* sub_message = - upb_Message_GetMessage(message, field, NULL); + upb_Message_GetMessage(src, field, NULL); if (sub_message != NULL) { const upb_MiniTable* sub_message_table = upb_MiniTable_GetSubMessageTable(mini_table, field); - upb_Message* cloned_sub_message = + upb_Message* dst_sub_message = upb_Message_DeepClone(sub_message, sub_message_table, arena); - if (cloned_sub_message == NULL) { + if (dst_sub_message == NULL) { return NULL; } - upb_Message_SetMessage(clone, mini_table, field, - cloned_sub_message); + upb_Message_SetMessage(dst, mini_table, field, dst_sub_message); } } break; case kUpb_CType_String: case kUpb_CType_Bytes: { - upb_StringView str = - upb_Message_GetString(message, field, empty_string); + upb_StringView str = upb_Message_GetString(src, field, empty_string); if (str.size != 0) { if (!upb_Message_SetString( - clone, field, upb_Clone_StringView(str, arena), arena)) { + dst, field, upb_Clone_StringView(str, arena), arena)) { return NULL; } } @@ -233,17 +227,16 @@ upb_Message* upb_Message_DeepClone(const upb_Message* message, } } else { if (upb_MessageField_IsMap(field)) { - const upb_Map* map = upb_Message_GetMap(message, field); + const upb_Map* map = upb_Message_GetMap(src, field); if (map != NULL) 
{ - if (!upb_Message_Map_DeepClone(map, mini_table, field, clone, - arena)) { + if (!upb_Message_Map_DeepClone(map, mini_table, field, dst, arena)) { return NULL; } } } else { - const upb_Array* array = upb_Message_GetArray(message, field); + const upb_Array* array = upb_Message_GetArray(src, field); if (array != NULL) { - if (!upb_Message_Array_DeepClone(array, mini_table, field, clone, + if (!upb_Message_Array_DeepClone(array, mini_table, field, dst, arena)) { return NULL; } @@ -253,33 +246,58 @@ upb_Message* upb_Message_DeepClone(const upb_Message* message, } // Clone extensions. size_t ext_count; - const upb_Message_Extension* ext = _upb_Message_Getexts(message, &ext_count); + const upb_Message_Extension* ext = _upb_Message_Getexts(src, &ext_count); for (size_t i = 0; i < ext_count; ++i) { const upb_Message_Extension* msg_ext = &ext[i]; - upb_Message_Extension* cloned_ext = - _upb_Message_GetOrCreateExtension(clone, msg_ext->ext, arena); - if (!cloned_ext) { - return NULL; - } - if (!upb_Clone_ExtensionValue(msg_ext->ext, msg_ext, cloned_ext, arena)) { - return NULL; + const upb_MiniTableField* field = &msg_ext->ext->field; + upb_Message_Extension* dst_ext = + _upb_Message_GetOrCreateExtension(dst, msg_ext->ext, arena); + if (!dst_ext) return NULL; + if (!upb_IsRepeatedOrMap(field)) { + if (!upb_Clone_ExtensionValue(msg_ext->ext, msg_ext, dst_ext, arena)) { + return NULL; + } + } else { + upb_Array* msg_array = (upb_Array*)msg_ext->data.ptr; + UPB_ASSERT(msg_array); + upb_Array* cloned_array = + upb_Array_DeepClone(msg_array, upb_MiniTableField_CType(field), + msg_ext->ext->sub.submsg, arena); + if (!cloned_array) { + return NULL; + } + dst_ext->data.ptr = (void*)cloned_array; } } // Clone unknowns. size_t unknown_size = 0; - const char* ptr = upb_Message_GetUnknown(message, &unknown_size); + const char* ptr = upb_Message_GetUnknown(src, &unknown_size); if (unknown_size != 0) { UPB_ASSERT(ptr); // Make a copy into destination arena. 
- void* cloned_unknowns = upb_Arena_Malloc(arena, unknown_size); - if (cloned_unknowns == NULL) { - return NULL; - } - memcpy(cloned_unknowns, ptr, unknown_size); - if (!_upb_Message_AddUnknown(clone, cloned_unknowns, unknown_size, arena)) { + void* dst_unknowns = upb_Arena_Malloc(arena, unknown_size); + if (dst_unknowns == NULL) return NULL; + memcpy(dst_unknowns, ptr, unknown_size); + if (!_upb_Message_AddUnknown(dst, dst_unknowns, unknown_size, arena)) { return NULL; } } - return clone; + return dst; +} + +void upb_Message_DeepCopy(upb_Message* dst, const upb_Message* src, + const upb_MiniTable* mini_table, upb_Arena* arena) { + upb_Message_Clear(dst, mini_table); + _upb_Message_Copy(dst, src, mini_table, arena); +} + +// Deep clones a message using the provided target arena. +// +// Returns NULL on failure. +upb_Message* upb_Message_DeepClone(const upb_Message* message, + const upb_MiniTable* mini_table, + upb_Arena* arena) { + upb_Message* clone = upb_Message_New(mini_table, arena); + return _upb_Message_Copy(clone, message, mini_table, arena); } diff --git a/upb/message/copy.h b/upb/message/copy.h index 18508a183a..f760e5dc35 100644 --- a/upb/message/copy.h +++ b/upb/message/copy.h @@ -54,6 +54,10 @@ upb_Map* upb_Map_DeepClone(const upb_Map* map, upb_CType key_type, const upb_MiniTable* map_entry_table, upb_Arena* arena); +// Deep copies the message from src to dst. +void upb_Message_DeepCopy(upb_Message* dst, const upb_Message* src, + const upb_MiniTable* mini_table, upb_Arena* arena); + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/upb/message/internal.h b/upb/message/internal.h index 8dbe673a6c..0b301a2a60 100644 --- a/upb/message/internal.h +++ b/upb/message/internal.h @@ -115,9 +115,6 @@ UPB_INLINE upb_Message_Internal* upb_Message_Getinternal( return (upb_Message_Internal*)((char*)msg - size); } -// Clears the given message. 
-void _upb_Message_Clear(upb_Message* msg, const upb_MiniTable* l); - // Discards the unknown fields for this message only. void _upb_Message_DiscardUnknown_shallow(upb_Message* msg); diff --git a/upb/message/message.c b/upb/message/message.c index 50d7bd6b05..b47cddd2df 100644 --- a/upb/message/message.c +++ b/upb/message/message.c @@ -46,12 +46,6 @@ upb_Message* upb_Message_New(const upb_MiniTable* mini_table, return _upb_Message_New(mini_table, arena); } -void _upb_Message_Clear(upb_Message* msg, const upb_MiniTable* l) { - // Note: Can't use UPB_PTR_AT() here because we are doing pointer subtraction. - char* mem = (char*)msg - sizeof(upb_Message_Internal); - memset(mem, 0, upb_msg_sizeof(l)); -} - static bool realloc_internal(upb_Message* msg, size_t need, upb_Arena* arena) { upb_Message_Internal* in = upb_Message_Getinternal(msg); if (!in->internal) { diff --git a/upb/reflection/message.c b/upb/reflection/message.c index dc683c1df0..cb89adf135 100644 --- a/upb/reflection/message.c +++ b/upb/reflection/message.c @@ -121,7 +121,7 @@ void upb_Message_ClearFieldByDef(upb_Message* msg, const upb_FieldDef* f) { } void upb_Message_ClearByDef(upb_Message* msg, const upb_MessageDef* m) { - _upb_Message_Clear(msg, upb_MessageDef_MiniTable(m)); + upb_Message_Clear(msg, upb_MessageDef_MiniTable(m)); } bool upb_Message_Next(const upb_Message* msg, const upb_MessageDef* m,